In [12]:
import os, sys, configparser, logging, argparse
# Add the directory two levels above the current working directory to the import
# path (presumably the repository root holding the XAE package imported below —
# TODO confirm). Note this depends on where the kernel was started.
sys.path.append('/'.join(os.getcwd().split('/')[:-2]))

import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Expose only physical GPU index 2 to CUDA. This is set before torch is imported
# so it is already in place when CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES']='2'

import torch
import torch.nn as nn
import torch.optim as optim

from XAE.dataset import rmMNIST
from XAE.util import init_params

In [13]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

d = 64  # base channel count; also the hidden width of the classifier head


def _conv_bn_relu(in_ch, out_ch, **conv_kwargs):
    """Return [Conv2d(kernel 4, no bias), BatchNorm2d, ReLU(inplace)] as a flat list.

    A list (not a nested Sequential) is returned so the caller can unpack it and
    keep the layers at the top level of the enclosing nn.Sequential.
    """
    return [
        nn.Conv2d(in_ch, out_ch, kernel_size=4, bias=False, **conv_kwargs),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
    ]


# Convolutional feature extractor. For a 1x28x28 input the two stride-2 stages
# shrink the spatial size 28 -> 14 -> 7 (the 'same'-padded stages keep it), so
# the flattened feature vector has 2*d*7*7 entries.
embed_data = nn.Sequential(
    *_conv_bn_relu(1, d, stride=2, padding=1),
    *_conv_bn_relu(d, d, padding='same'),
    *_conv_bn_relu(d, 2 * d, stride=2, padding=1),
    *_conv_bn_relu(2 * d, 2 * d, padding='same'),
    nn.Flatten(),
).to(device)

# Classifier head on top of the flattened features: 10 output scores.
# No Softmax layer here: the head emits raw scores, which is what
# nn.CrossEntropyLoss (used for training in this notebook) takes as input.
embed_condition = nn.Sequential(
    nn.Linear(2 * d * 7 * 7, d),
    nn.BatchNorm1d(d),
    nn.ReLU(True),
    nn.Linear(d, 10),
).to(device)

# Apply the project's parameter initialisation to both networks, in this order.
for net in (embed_data, embed_condition):
    init_params(net)

In [17]:
# The dataset and loader are created first so that the batch peek below works
# on a fresh kernel (Restart & Run All); drawing a batch requires
# `train_generator` to exist already.
# The data directory can be overridden with the MNIST_ROOT environment variable;
# the default is the location this notebook was originally run with.
data_root = os.environ.get('MNIST_ROOT', '/home/reddragon/data/MNIST')
train_data = rmMNIST(data_root, train = True, label = True, aux = [list(range(10)), []], portion = 0.5)
# Batches of 100, reshuffled each pass; the last incomplete batch is dropped.
train_generator = torch.utils.data.DataLoader(train_data, 100, num_workers = 5, shuffle = True, pin_memory=True, drop_last=True)

# Sanity check: pull a single batch and display the shape of its first element.
a, b = next(iter(train_generator))
a.shape

torch.Size([100, 1, 28, 28])

In [20]:
epoch = 10
# One Adam optimiser updates the feature extractor and the classifier head jointly.
joint_params = [*embed_data.parameters(), *embed_condition.parameters()]
opt = optim.Adam(joint_params, lr = 1e-3, betas = (0.9, 0.999))
crit = nn.CrossEntropyLoss()

running_loss = 0.0
for k in range(epoch):
    for data, condition in train_generator:
        opt.zero_grad()
        x, y = data.to(device), condition.to(device)
        # Forward through both networks, then one optimisation step on the batch loss.
        output = embed_condition(embed_data(x))
        loss = crit(output, y)
        loss.backward()
        opt.step()
        running_loss += loss.item()

    # Report the mean per-batch loss of this epoch, then restart the accumulator.
    print('[%d] loss: %.3f' % (k + 1, running_loss / len(train_generator)))
    running_loss = 0.0

[1] loss: 0.828
[2] loss: 0.070
[3] loss: 0.034
[4] loss: 0.022
[5] loss: 0.015
[6] loss: 0.011
[7] loss: 0.006
[8] loss: 0.007
[9] loss: 0.004
[10] loss: 0.006


In [21]:
# Write each network's state dict to its own file in the current working directory.
for module, weight_file in (
    (embed_data, 'embed_data_weight.pt'),
    (embed_condition, 'embed_condition_weight.pt'),
):
    torch.save(module.state_dict(), weight_file)

In [22]:
# Same layer stack as the classifier head `embed_condition`, with a trailing
# Softmax over dim 1. Softmax has no parameters, so the state-dict keys of the
# two modules line up.
ec_layers = [
    nn.Linear(2 * d * 7 * 7, d),
    nn.BatchNorm1d(d),
    nn.ReLU(True),
    nn.Linear(d, 10),
    nn.Softmax(dim = 1),
]
ec = nn.Sequential(*ec_layers).to(device)

In [23]:
# Load the saved classifier-head weights into `ec`. map_location remaps the
# stored tensors onto the current `device`, so a checkpoint written from a GPU
# session can also be loaded when only the CPU is available.
ec.load_state_dict(torch.load('embed_condition_weight.pt', map_location = device))

<All keys matched successfully>

In [27]:
list(embed_condition.parameters())

[Parameter containing:
 tensor([[-0.0272, -0.0072,  0.0420,  ...,  0.0368,  0.0173,  0.0251],
         [ 0.0459,  0.0209,  0.0736,  ...,  0.0055,  0.0183, -0.0106],
         [ 0.0047,  0.0078,  0.0322,  ..., -0.0063,  0.0082,  0.0087],
         ...,
         [ 0.0373,  0.0126, -0.0186,  ...,  0.0112, -0.0154, -0.0446],
         [ 0.0137, -0.0045, -0.0916,  ..., -0.0115, -0.0617, -0.0774],
         [-0.0646,  0.0219, -0.0645,  ..., -0.0240, -0.0071, -0.0429]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([0.1588, 0.1544, 0.1111, 0.1165, 0.1430, 0.1775, 0.1619, 0.0993, 0.1717,
         0.1233, 0.1011, 0.1165, 0.1227, 0.2013, 0.1421, 0.1401, 0.1215, 0.0986,
         0.1408, 0.1818, 0.1606, 0.1956, 0.1584, 0.1229, 0.1178, 0.1299, 0.1662,
         0.1256, 0.1422, 0.1858, 0.1489, 0.1370, 0.1074, 0.1193, 0.1210, 0.1699,
         0.1548, 0.1660, 0.1475, 0.1372, 0.1172, 0.1652, 0.1245, 0.1620, 0.1450,
         0.0711, 0.1115, 0.1566, 0.1462, 0.0948, 0.1572, 0.131

In [26]:
list(ec.parameters())

[Parameter containing:
 tensor([[-0.0272, -0.0072,  0.0420,  ...,  0.0368,  0.0173,  0.0251],
         [ 0.0459,  0.0209,  0.0736,  ...,  0.0055,  0.0183, -0.0106],
         [ 0.0047,  0.0078,  0.0322,  ..., -0.0063,  0.0082,  0.0087],
         ...,
         [ 0.0373,  0.0126, -0.0186,  ...,  0.0112, -0.0154, -0.0446],
         [ 0.0137, -0.0045, -0.0916,  ..., -0.0115, -0.0617, -0.0774],
         [-0.0646,  0.0219, -0.0645,  ..., -0.0240, -0.0071, -0.0429]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([0.1588, 0.1544, 0.1111, 0.1165, 0.1430, 0.1775, 0.1619, 0.0993, 0.1717,
         0.1233, 0.1011, 0.1165, 0.1227, 0.2013, 0.1421, 0.1401, 0.1215, 0.0986,
         0.1408, 0.1818, 0.1606, 0.1956, 0.1584, 0.1229, 0.1178, 0.1299, 0.1662,
         0.1256, 0.1422, 0.1858, 0.1489, 0.1370, 0.1074, 0.1193, 0.1210, 0.1699,
         0.1548, 0.1660, 0.1475, 0.1372, 0.1172, 0.1652, 0.1245, 0.1620, 0.1450,
         0.0711, 0.1115, 0.1566, 0.1462, 0.0948, 0.1572, 0.131