In [None]:
%pip install wandb
%pip install graphviz
%pip install torchviz
import wandb
wandb.login()#doesnt detect WANDB_NOTEBOOK_NAME on windows

In [None]:
wandb.init(project="test-project", entity="simclr-doctoral-research")

In [None]:
wandb.config = {
  "learning_rate": 0.001,
  "epochs": 10,
  "batch_size": 512
}

In [None]:
import os
import wandb
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')


def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 10
batch_size = 512
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset_train = MNIST('./data', transform=img_transform, download=True,train = True)
dataset_test = MNIST('./data', transform=img_transform, download=True,train = False)

dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)


class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(True),
            nn.Linear(256, 128))
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256,28 * 28), 
            nn.Tanh())
    def forward(self, x,only_encode=False):
        if only_encode:
            return self.encoder(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x





In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=1e-5)#this line is needed
#to freeze gradients

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img).to(device)
        # ===================forward=====================
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item()))
    if epoch % 10 == 0:
        pic = to_img(output.cpu().data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
    wandb.log({"loss": loss})

    wandb.watch(model)
    #print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.data[0]))

pic = to_img(output.cpu().data)
save_image(pic, './mlp_img/image_final.png')


In [None]:
#test of the model
test_data = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

for data in test_data:
    img, _ = data
    img = img.view(img.size(0), -1)
    img = Variable(img).to(device)
    output = model(img)
    pic = to_img(output.cpu().data)
    save_image(pic, './mlp_img/image_test.png')
#accuracy of the autoencoder comparing input and output
loss_sum = 0
for data in test_data:
    img, _ = data
    img = img.view(img.size(0), -1)
    img = Variable(img).to(device)
    output = model(img)
    loss = criterion(output, img)
    loss_sum += loss.item()

print(loss_sum/len(test_data))


In [None]:
from torchviz import make_dot
y = model(img)
make_dot(y, params=dict(list(model.named_parameters()))).render("torchviz", format="png")

Now the encoder is going to be frozen

In [119]:
#freeze the encoder
# print(model.state_dict())

model.encoder[0].weight.requires_grad = False
model.encoder[0].bias.requires_grad = False
model.encoder[2].weight.requires_grad = False
model.encoder[2].bias.requires_grad = False


for name, param in model.named_parameters():
    print(name, param.requires_grad)



# if param.requires_grad:print(name)
# for param in model.parameters():
#     print(param)
#     param.requires_grad = False


encoder.0.weight False
encoder.0.bias False
encoder.2.weight False
encoder.2.bias False
decoder.0.weight True
decoder.0.bias True
decoder.2.weight True
decoder.2.bias True


In [120]:
#save the encoder
os.makedirs('./saved_models', exist_ok=True)
torch.save(model.state_dict(), './saved_models/autoencoder.pth')

In [121]:
class linear_classifier(nn.Module):
    def __init__(self):
        super(linear_classifier, self).__init__()
        self.fc1 = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        

In [122]:
#load the encoder
model.load_state_dict(torch.load('./saved_models/autoencoder.pth'))
#add a mlp to the encoder
model.add_module('linear_classifier', linear_classifier())



In [126]:
#joined model
class joined_model(nn.Module):
    def __init__(self):
        super(joined_model, self).__init__()
        self.encoder = autoencoder.encoder
        self.classifier = linear_classifier()
    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x
joined_model = joined_model().to(device)

In [127]:
#verify weight are frozen

for name, param in joined_model.named_parameters():
    print(name, param.requires_grad)
print(model.parameters)


encoder.0.weight False
encoder.0.bias False
encoder.2.weight False
encoder.2.bias False
classifier.fc1.weight True
classifier.fc1.bias True
classifier.fc2.weight True
classifier.fc2.bias True
<bound method Module.parameters of autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=784, bias=True)
    (3): Tanh()
  )
  (linear_classifier): linear_classifier(
    (fc1): Linear(in_features=128, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)>


In [129]:
#train the model on the labeled data
# model = model.to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, joined_model.parameters()), lr=learning_rate, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    for data in dataloader:
        img, label = data
        img = img.view(img.size(0), -1)
        img = Variable(img).to(device)
        label = Variable(label).to(device)
        # ===================forward=====================
        output = joined_model(img)#This should only use the encoder
        loss = criterion(output, label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'
            .format(epoch + 1, num_epochs, loss.item()))
    wandb.log({"loss": loss})
    wandb.watch(joined_model)

epoch [1/10], loss:0.2708
epoch [2/10], loss:0.3691
epoch [3/10], loss:0.4757
epoch [4/10], loss:0.4167
epoch [5/10], loss:0.3336
epoch [6/10], loss:0.3332
epoch [7/10], loss:0.4741
epoch [8/10], loss:0.2894
epoch [9/10], loss:0.2555
epoch [10/10], loss:0.2197


In [136]:
#evaluate the model
loss_sum = 0
test_data = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
for data in test_data:
    img, label = data
    img = img.view(img.size(0), -1)
    img = Variable(img).to(device)
    label = Variable(label).to(device)
    output = joined_model(img)
    loss = criterion(output, label)
    loss_sum += loss.item()
print(loss_sum/len(test_data))


0.30486280769109725


In [140]:
import matplotlib.pyplot as plt
#show results of the model
test_data = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
for data in test_data:
    img, label = data
    img = img.view(img.size(0), -1)
    img = Variable(img).to(device)
    label = Variable(label).to(device)
    output = joined_model(img)
    pic = to_img(img.cpu().data)
    # save_image(pic, './mlp_img/image_test.png')
    #show predicted label for the image
    print(output.argmax(dim=1))
    #show the true label for the image
    print(label)
    #show the image
    plt.imshow(img.cpu().data[0].numpy().transpose(1, 2, 0))
    plt.show()
    plt.close()

    

tensor([5, 0, 6, 5, 0, 1, 4, 9, 2, 1, 2, 7, 3, 9, 8, 5, 4, 3, 3, 1, 2, 3, 0, 3,
        3, 3, 7, 4, 5, 9, 1, 8, 7, 1, 6, 4, 4, 6, 9, 5, 9, 7, 1, 0, 7, 9, 3, 3,
        8, 2, 8, 0, 1, 4, 9, 0, 4, 2, 7, 4, 6, 4, 6, 7, 8, 3, 1, 4, 8, 0, 2, 1,
        8, 0, 8, 3, 2, 8, 7, 5, 0, 6, 6, 5, 1, 5, 2, 7, 5, 5, 4, 5, 7, 6, 2, 8,
        3, 0, 4, 9, 5, 0, 2, 8, 6, 6, 8, 3, 8, 8, 8, 0, 0, 7, 2, 6, 2, 4, 4, 7,
        5, 4, 1, 0, 8, 2, 7, 6, 8, 1, 3, 6, 9, 6, 0, 2, 4, 9, 7, 5, 1, 8, 0, 7,
        1, 1, 6, 4, 4, 1, 2, 5, 7, 4, 1, 7, 2, 2, 2, 9, 3, 5, 3, 1, 7, 0, 0, 0,
        7, 8, 2, 7, 1, 9, 5, 3, 5, 2, 1, 6, 3, 2, 3, 7, 3, 1, 6, 0, 0, 5, 7, 1,
        8, 8, 5, 4, 8, 5, 7, 7, 9, 9, 2, 3, 3, 9, 3, 8, 3, 7, 3, 6, 3, 6, 0, 7,
        1, 5, 0, 8, 8, 3, 8, 5, 1, 1, 8, 1, 4, 3, 2, 3, 7, 9, 1, 3, 0, 7, 4, 1,
        5, 3, 3, 1, 6, 2, 3, 5, 8, 6, 6, 1, 9, 8, 4, 2, 8, 1, 7, 7, 9, 5, 3, 4,
        7, 1, 8, 6, 0, 9, 6, 1, 8, 6, 2, 1, 9, 3, 2, 2, 3, 3, 8, 7, 4, 2, 7, 1,
        8, 2, 7, 3, 1, 1, 5, 3, 2, 1, 0,

ValueError: axes don't match array

In [130]:
#save the joined_model
torch.save(joined_model.state_dict(), './saved_models/joined_model.pth')
