In [1]:
import argparse
import logging
import os
import random
import sys
import time
import json
from typing_extensions import Required
import numpy as np
import copy
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import SubsetRandomSampler
import torch
from torch import nn

In [2]:
from dataloader import shakespeare_dataloaders
from model import LSTM_shakespeare_1L

In [3]:
# Configuration: seeds, device and hyper-parameters for the whole notebook.
# Seed every RNG the notebook imports so Restart-&-Run-All is reproducible
# (data shuffling, weight init and dropout all draw from these).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU instead of failing on the first `.to(device)` when no GPU exists.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 4   # samples per mini-batch (also passed to the data loaders)
LR = 0.002       # optimizer learning rate

epoch = 10       # number of train/test rounds in the main loop
clients = 10     # forwarded to shakespeare_dataloaders(clients=...)


In [4]:
def _detach_state(state):
    """Return ``state`` with every tensor cut from the autograd graph.

    Accepts a bare tensor or a (possibly nested) tuple/list of tensors, such
    as an LSTM ``(h, c)`` pair. Any other value is returned unchanged.
    """
    if isinstance(state, torch.Tensor):
        return state.detach()
    if isinstance(state, (tuple, list)):
        detached = [_detach_state(part) for part in state]
        return tuple(detached) if isinstance(state, tuple) else detached
    return state


def train(net=None, opt=None, loss_fn=None, loader=None, dev=None):
    """Run one training epoch and return ``(accuracy, mean_loss)``.

    Every argument is optional. When omitted, the notebook-level ``model``,
    ``optimizer``, ``criterion``, ``train_data`` and ``device`` are used, so
    the zero-argument call ``train()`` behaves as before.

    Args:
        net: Model exposing ``zero_state(batch_size=..., device=...)`` and a
            forward ``net(data, state) -> (logits, new_state)``.
        opt: Optimizer over ``net``'s parameters.
        loss_fn: Loss called as ``loss_fn(logits, target)``.
        loader: Iterable of ``(data, target)`` batches. Its ``.dataset`` is
            only used for the progress message.
        dev: Device that the model and each batch are moved to.

    Returns:
        ``(train_acc, epoch_loss)``: the fraction of seen samples predicted
        correctly, and the sample-weighted mean loss over the epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    loader = train_data if loader is None else loader
    dev = device if dev is None else dev

    net.to(dev)
    net.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    lstm_state = None
    state_batch = None

    for count, (data, target) in enumerate(loader, start=1):
        data = data.to(dev)
        target = target.to(dev)
        cur_batch = data.size(0)

        if lstm_state is None or cur_batch != state_batch:
            # The first batch, or a differently sized (e.g. final partial)
            # batch, cannot reuse a state sized for another batch: start fresh.
            lstm_state = net.zero_state(batch_size=cur_batch, device=dev)
            state_batch = cur_batch
        else:
            # Keep the hidden values but drop their history (truncated BPTT).
            # Otherwise backward() would walk into the previous batch's graph,
            # which was already freed by that batch's backward().
            lstm_state = _detach_state(lstm_state)

        opt.zero_grad()
        output, lstm_state = net(data, lstm_state)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()

        # Weight by batch size so a smaller last batch doesn't skew the mean.
        loss_sum += loss.item() * cur_batch
        num_seen += cur_batch
        num_correct += output.argmax(dim=-1).eq(target).sum().item()

        if count % 1000 == 0:
            print("Train ({}/{})".format(num_seen, len(loader.dataset)))

    if num_seen == 0:
        raise ValueError("train(): the data loader yielded no batches")

    # Divide by samples actually seen (not len(dataset)), which stays correct
    # when a sampler iterates only a subset of the dataset.
    train_acc = num_correct / num_seen
    epoch_loss = loss_sum / num_seen
    return train_acc, epoch_loss

In [5]:
def test():
    model.to(device)
    model.eval()
    
    batch_loss = []
    num_correct = 0

    criterion = nn.CrossEntropyLoss().to(device)

    with torch.no_grad():
        lstm_state = model.zero_state(batch_size=batch_size, device=device)
        for data, target in test_data:

            data = data.to(device)
            target = target.to(device)
            pred, _ = model(data, lstm_state)
            loss = criterion(pred, target)
            batch_loss.append(loss.item())

            _, predicted = torch.max(pred, -1)
            correct = predicted.eq(target).sum()

            num_correct += correct.item()
            
    epoch_loss = sum(batch_loss) / len(batch_loss)
    test_acc = num_correct/len(test_data.dataset)

    return test_acc, epoch_loss

In [6]:
# Build the Shakespeare loaders and pick out the two used by train()/test().
# NOTE(review): the return value is indexed positionally; positions 2 and 3
# are presumably the pooled train and test loaders. Confirm against
# shakespeare_dataloaders before relying on it.
dataset = shakespeare_dataloaders(
    root="./shakespeare",
    batch_size=batch_size,
    clients=clients,
)
train_data, test_data = dataset[2], dataset[3]

# Fresh, untrained model; train() moves it to `device` on each call.
model = LSTM_shakespeare_1L()

In [7]:
# Classification loss on the configured device, and Adam over every model
# parameter at the configured learning rate.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

In [8]:
# Main loop: each round runs one training epoch, then evaluates on the test
# loader, printing accuracy and loss for both.
for i in range(epoch):
    print("##### Round: {} #####".format(i))

    train_acc, train_loss = train()
    print(
        "==> Train acc:{:.4f}, loss:{:.6f}".format(train_acc, train_loss)
    )

    test_acc, test_loss = test()
    print(
        "==> Test acc:{:.4f}, loss:{:.6f}".format(test_acc, test_loss)
    )

##### Round: 0 #####
Train (4000/43608)
Train (8000/43608)
Train (12000/43608)
Train (16000/43608)
Train (20000/43608)
Train (24000/43608)
Train (28000/43608)
Train (32000/43608)
Train (36000/43608)
Train (40000/43608)
==> Train acc:0.3828, loss:2.257442
==> Test acc:0.4123, loss:2.202125
##### Round: 1 #####
Train (4000/43608)
Train (8000/43608)
Train (12000/43608)
Train (16000/43608)
Train (20000/43608)
Train (24000/43608)
Train (28000/43608)
Train (32000/43608)
Train (36000/43608)
Train (40000/43608)
==> Train acc:0.4329, loss:2.038142
==> Test acc:0.4419, loss:2.108703
##### Round: 2 #####
Train (4000/43608)
Train (8000/43608)
Train (12000/43608)
Train (16000/43608)
Train (20000/43608)
Train (24000/43608)
Train (28000/43608)
Train (32000/43608)
Train (36000/43608)
Train (40000/43608)
==> Train acc:0.4615, loss:1.925741
==> Test acc:0.4607, loss:2.064361
##### Round: 3 #####
Train (4000/43608)
Train (8000/43608)
Train (12000/43608)
Train (16000/43608)
Train (20000/43608)
Train (2400