In [1]:
!pip install ipynb

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import ipynb

In [2]:
import torch
import torch.nn as nn
import torch.functional as F 
from tqdm import tqdm
from ipynb.fs.full.mscred import MSCRED
from ipynb.fs.full.data import load_data
import matplotlib.pyplot as plt
import numpy as np
import os

def train(dataLoader, model, optimizer, epochs, device):
    model = model.to(device)
    print("------training on {}-------".format(device))
    for epoch in range(epochs):
        train_l_sum,n = 0.0, 0
        for x in tqdm(dataLoader):
            x = x.to(device)
            x = x.squeeze()
            #print(type(x))
            l = torch.mean((model(x)-x[-1].unsqueeze(0))**2)
            train_l_sum += l
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            n += 1
            #print("[Epoch %d/%d][Batch %d/%d] [loss: %f]" % (epoch+1, epochs, n, len(dataLoader), l.item()))
            
        print("[Epoch %d/%d] [loss: %f]" % (epoch+1, epochs, train_l_sum/n))

def test(dataLoader, model):
    print("------Testing-------")
    index = 800
    loss_list = []
    reconstructed_data_path = "./data/matrix_data"
    with torch.no_grad():
        for x in dataLoader:
            x = x.to(device)
            x = x.squeeze()
            reconstructed_matrix = model(x) 
            path_temp = os.path.join(reconstructed_data_path, 'reconstructed_data_' + str(index) + ".npy")
            print(path_temp)
            np.save(path_temp, reconstructed_matrix.cpu().detach().numpy())
            # l = criterion(reconstructed_matrix, x[-1].unsqueeze(0)).mean()
            # loss_list.append(l)
            # print("[test_index %d] [loss: %f]" % (index, l.item()))
            index += 1


if __name__ == '__main__':
    device = torch.device("cuda" if not torch.cuda.is_available() else "cpu")
    print("device is", device)
    dataLoader = load_data()
    mscred = MSCRED(3, 256)

    
    # mscred.load_state_dict(torch.load("./checkpoints/model1.pth"))
    optimizer = torch.optim.Adam(mscred.parameters(), lr = 0.0002)
    train(dataLoader["train"], mscred, optimizer, 10,device)
    print("saving the model....")
    torch.save(mscred.state_dict(), "./data/model2.pth")

    
    mscred.load_state_dict(torch.load("./data/model2.pth"))
    mscred.to(device)
    test(dataLoader["test"], mscred)


device is cpu
------training on cpu-------


  "    attention_w = []\n",
100%|█████████████████████████████████████████| 789/789 [02:21<00:00,  5.58it/s]


[Epoch 1/10] [loss: 0.001221]


100%|█████████████████████████████████████████| 789/789 [02:18<00:00,  5.70it/s]


[Epoch 2/10] [loss: 0.000147]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.73it/s]


[Epoch 3/10] [loss: 0.000106]


100%|█████████████████████████████████████████| 789/789 [02:16<00:00,  5.77it/s]


[Epoch 4/10] [loss: 0.000083]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.72it/s]


[Epoch 5/10] [loss: 0.000066]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.74it/s]


[Epoch 6/10] [loss: 0.000057]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.72it/s]


[Epoch 7/10] [loss: 0.000054]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.74it/s]


[Epoch 8/10] [loss: 0.000047]


100%|█████████████████████████████████████████| 789/789 [02:17<00:00,  5.73it/s]


[Epoch 9/10] [loss: 0.000040]


100%|█████████████████████████████████████████| 789/789 [02:16<00:00,  5.77it/s]


[Epoch 10/10] [loss: 0.000039]
saving the model....
------Testing-------
./data/matrix_data/reconstructed_data_800.npy
./data/matrix_data/reconstructed_data_801.npy
./data/matrix_data/reconstructed_data_802.npy
./data/matrix_data/reconstructed_data_803.npy
./data/matrix_data/reconstructed_data_804.npy
./data/matrix_data/reconstructed_data_805.npy
./data/matrix_data/reconstructed_data_806.npy
./data/matrix_data/reconstructed_data_807.npy
./data/matrix_data/reconstructed_data_808.npy
./data/matrix_data/reconstructed_data_809.npy
./data/matrix_data/reconstructed_data_810.npy
./data/matrix_data/reconstructed_data_811.npy
./data/matrix_data/reconstructed_data_812.npy
./data/matrix_data/reconstructed_data_813.npy
./data/matrix_data/reconstructed_data_814.npy
./data/matrix_data/reconstructed_data_815.npy
./data/matrix_data/reconstructed_data_816.npy
./data/matrix_data/reconstructed_data_817.npy
./data/matrix_data/reconstructed_data_818.npy
./data/matrix_data/reconstructed_data_819.npy
./data/

./data/matrix_data/reconstructed_data_977.npy
./data/matrix_data/reconstructed_data_978.npy
./data/matrix_data/reconstructed_data_979.npy
./data/matrix_data/reconstructed_data_980.npy
./data/matrix_data/reconstructed_data_981.npy
./data/matrix_data/reconstructed_data_982.npy
./data/matrix_data/reconstructed_data_983.npy
./data/matrix_data/reconstructed_data_984.npy
./data/matrix_data/reconstructed_data_985.npy
./data/matrix_data/reconstructed_data_986.npy
./data/matrix_data/reconstructed_data_987.npy
./data/matrix_data/reconstructed_data_988.npy
./data/matrix_data/reconstructed_data_989.npy
./data/matrix_data/reconstructed_data_990.npy
./data/matrix_data/reconstructed_data_991.npy
./data/matrix_data/reconstructed_data_992.npy
./data/matrix_data/reconstructed_data_993.npy
./data/matrix_data/reconstructed_data_994.npy
./data/matrix_data/reconstructed_data_995.npy
./data/matrix_data/reconstructed_data_996.npy
./data/matrix_data/reconstructed_data_997.npy
./data/matrix_data/reconstructed_d

./data/matrix_data/reconstructed_data_1153.npy
./data/matrix_data/reconstructed_data_1154.npy
./data/matrix_data/reconstructed_data_1155.npy
./data/matrix_data/reconstructed_data_1156.npy
./data/matrix_data/reconstructed_data_1157.npy
./data/matrix_data/reconstructed_data_1158.npy
./data/matrix_data/reconstructed_data_1159.npy
./data/matrix_data/reconstructed_data_1160.npy
./data/matrix_data/reconstructed_data_1161.npy
./data/matrix_data/reconstructed_data_1162.npy
./data/matrix_data/reconstructed_data_1163.npy
./data/matrix_data/reconstructed_data_1164.npy
./data/matrix_data/reconstructed_data_1165.npy
./data/matrix_data/reconstructed_data_1166.npy
./data/matrix_data/reconstructed_data_1167.npy
./data/matrix_data/reconstructed_data_1168.npy
./data/matrix_data/reconstructed_data_1169.npy
./data/matrix_data/reconstructed_data_1170.npy
./data/matrix_data/reconstructed_data_1171.npy
./data/matrix_data/reconstructed_data_1172.npy
./data/matrix_data/reconstructed_data_1173.npy
./data/matrix

./data/matrix_data/reconstructed_data_1329.npy
./data/matrix_data/reconstructed_data_1330.npy
./data/matrix_data/reconstructed_data_1331.npy
./data/matrix_data/reconstructed_data_1332.npy
./data/matrix_data/reconstructed_data_1333.npy
./data/matrix_data/reconstructed_data_1334.npy
./data/matrix_data/reconstructed_data_1335.npy
./data/matrix_data/reconstructed_data_1336.npy
./data/matrix_data/reconstructed_data_1337.npy
./data/matrix_data/reconstructed_data_1338.npy
./data/matrix_data/reconstructed_data_1339.npy
./data/matrix_data/reconstructed_data_1340.npy
./data/matrix_data/reconstructed_data_1341.npy
./data/matrix_data/reconstructed_data_1342.npy
./data/matrix_data/reconstructed_data_1343.npy
./data/matrix_data/reconstructed_data_1344.npy
./data/matrix_data/reconstructed_data_1345.npy
./data/matrix_data/reconstructed_data_1346.npy
./data/matrix_data/reconstructed_data_1347.npy
./data/matrix_data/reconstructed_data_1348.npy
./data/matrix_data/reconstructed_data_1349.npy
./data/matrix

./data/matrix_data/reconstructed_data_1505.npy
./data/matrix_data/reconstructed_data_1506.npy
./data/matrix_data/reconstructed_data_1507.npy
./data/matrix_data/reconstructed_data_1508.npy
./data/matrix_data/reconstructed_data_1509.npy
./data/matrix_data/reconstructed_data_1510.npy
./data/matrix_data/reconstructed_data_1511.npy
./data/matrix_data/reconstructed_data_1512.npy
./data/matrix_data/reconstructed_data_1513.npy
./data/matrix_data/reconstructed_data_1514.npy
./data/matrix_data/reconstructed_data_1515.npy
./data/matrix_data/reconstructed_data_1516.npy
./data/matrix_data/reconstructed_data_1517.npy
./data/matrix_data/reconstructed_data_1518.npy
./data/matrix_data/reconstructed_data_1519.npy
./data/matrix_data/reconstructed_data_1520.npy
./data/matrix_data/reconstructed_data_1521.npy
./data/matrix_data/reconstructed_data_1522.npy
./data/matrix_data/reconstructed_data_1523.npy
./data/matrix_data/reconstructed_data_1524.npy
./data/matrix_data/reconstructed_data_1525.npy
./data/matrix

./data/matrix_data/reconstructed_data_1681.npy
./data/matrix_data/reconstructed_data_1682.npy
./data/matrix_data/reconstructed_data_1683.npy
./data/matrix_data/reconstructed_data_1684.npy
./data/matrix_data/reconstructed_data_1685.npy
./data/matrix_data/reconstructed_data_1686.npy
./data/matrix_data/reconstructed_data_1687.npy
./data/matrix_data/reconstructed_data_1688.npy
./data/matrix_data/reconstructed_data_1689.npy
./data/matrix_data/reconstructed_data_1690.npy
./data/matrix_data/reconstructed_data_1691.npy
./data/matrix_data/reconstructed_data_1692.npy
./data/matrix_data/reconstructed_data_1693.npy
./data/matrix_data/reconstructed_data_1694.npy
./data/matrix_data/reconstructed_data_1695.npy
./data/matrix_data/reconstructed_data_1696.npy
./data/matrix_data/reconstructed_data_1697.npy
./data/matrix_data/reconstructed_data_1698.npy
./data/matrix_data/reconstructed_data_1699.npy
./data/matrix_data/reconstructed_data_1700.npy
./data/matrix_data/reconstructed_data_1701.npy
./data/matrix

./data/matrix_data/reconstructed_data_1857.npy
./data/matrix_data/reconstructed_data_1858.npy
./data/matrix_data/reconstructed_data_1859.npy
./data/matrix_data/reconstructed_data_1860.npy
./data/matrix_data/reconstructed_data_1861.npy
./data/matrix_data/reconstructed_data_1862.npy
./data/matrix_data/reconstructed_data_1863.npy
./data/matrix_data/reconstructed_data_1864.npy
./data/matrix_data/reconstructed_data_1865.npy
./data/matrix_data/reconstructed_data_1866.npy
./data/matrix_data/reconstructed_data_1867.npy
./data/matrix_data/reconstructed_data_1868.npy
./data/matrix_data/reconstructed_data_1869.npy
./data/matrix_data/reconstructed_data_1870.npy
./data/matrix_data/reconstructed_data_1871.npy
./data/matrix_data/reconstructed_data_1872.npy
./data/matrix_data/reconstructed_data_1873.npy
./data/matrix_data/reconstructed_data_1874.npy
./data/matrix_data/reconstructed_data_1875.npy
./data/matrix_data/reconstructed_data_1876.npy
./data/matrix_data/reconstructed_data_1877.npy
./data/matrix

In [12]:
pwd

'/home/ranjith/thesis/Pytorch-MSCRED'