In [1]:
import torch                                        # root package
from torch.utils.data import Dataset, DataLoader    # dataset representation and loading
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim 
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [23]:
train = datasets.MNIST(root='./data/',train=True, download=True, transform=ToTensor())
test = datasets.MNIST(root='./data/',train=False, download=True, transform=ToTensor())

In [24]:
bs = 64
train_loader = DataLoader(dataset=train, batch_size=bs, shuffle=True, num_workers=1, pin_memory=True)
test_loader = DataLoader(dataset=test, batch_size=bs, shuffle=True, num_workers=1, pin_memory=True)

In [32]:
class Teacher(nn.Module):
    def __init__(self):
        super(Teacher, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 1200),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.ReLU(),
            nn.Linear(1200, 10)
        )
    def forward(self, x):
        return self.layers(x)

In [33]:
teacher_model = Teacher()
teacher_model = teacher_model.to(device)

In [34]:
optimizer = optim.Adam(params=teacher_model.parameters(), lr=0.1)
loss = torch.nn.CrossEntropyLoss()

In [35]:
def train(model, epochs):
    for epoch in range(epochs):
        model.train()
        total_loss = []
        accurate = 0
        total = 0
        for x,y in train_loader:
#             x = x.to(device)
#             x = x.view(-1, 784)
#             y = y.to(device)
            total += x.shape[0]
            y_hat = model(x.view(-1,784))
            _, pred_label = torch.max(y_hat.data, 1)
#             print(y.shape)
            accurate += torch.sum(pred_label==y)
            train_loss = loss(y_hat, y)
            total_loss.append(train_loss)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        if epoch%5 == 0:
            print("Loss on train set : {} and Accuracy : {}".format((sum(total_loss)/len(train_loader)).item(), (accurate/total).item()))
            with torch.no_grad():
                model.eval()
                accurate = 0
                total = 0
                total_loss = []
                for x,y in test_loader:
#                     x = x.to(device)
                    x = x.view(-1,784)                                               
#                     y = y.to(device)
                    total += x.shape[0]
                    y_test = model(x)
                    _, pred_label = torch.max(y_test.data, 1)
#                     print(pred_label)
                    accurate += torch.sum(torch.argmax(y_test)==y)
                    total_loss.append(loss(y_test,y))
                print("Loss on test set : {} and Accuracy : {}".format((sum(total_loss)/len(test_loader)).item(), (accurate/total).item()))

In [36]:
train(teacher_model,11)

Loss on train set : 10.845402717590332 and Accuracy : 0.18386666476726532
Loss on test set : 2.289581060409546 and Accuracy : 0.003499999875202775
Loss on train set : 2.3105294704437256 and Accuracy : 0.1039000004529953
Loss on test set : 2.306274652481079 and Accuracy : 0.11019999533891678
Loss on train set : 2.310647964477539 and Accuracy : 0.10421666502952576
Loss on test set : 2.310048818588257 and Accuracy : 0.10979999601840973
