In [65]:
import torchvision
from torchvision.transforms import ToTensor


In [66]:
import torchvision
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import cv2
import numpy as np
import einops

In [67]:
class DDPM():
    def __init__(self,
                 device,
                 n_steps: int,
                 min_beta: float = 0.0001,
                 max_beta: float = 0.02):
        betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        alphas = 1 - betas
        alpha_bars = torch.empty_like(alphas)
        product = 1
        for i, alpha in enumerate(alphas):
            product *= alpha
            alpha_bars[i] = product
        self.betas = betas
        self.n_steps = n_steps
        self.alphas = alphas
        self.alpha_bars = alpha_bars
    def sample_forward(self, x, t, eps=None):
        alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
        if eps is None:
            eps = torch.randn_like(x)
        res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
        return res
    def sample_backward_step(self, x_t, t, net, simple_var=True):
        n = x_t.shape[0]
        t_tensor = torch.tensor([t] * n,
                                dtype=torch.long).to(x_t.device).unsqueeze(1)
        eps, _ = net(x_t, t_tensor)

        if t == 0:
            noise = 0
        else:
            if simple_var:
                var = self.betas[t]
            else:
                var = (1 - self.alpha_bars[t - 1]) / (
                    1 - self.alpha_bars[t]) * self.betas[t]
            noise = torch.randn_like(x_t)
            noise *= torch.sqrt(var)

        mean = (x_t -
                (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
                eps) / torch.sqrt(self.alphas[t])
        x_t = mean + noise

        return x_t
    def sample_backward(self, img_shape, net, device, simple_var=True):
        x = torch.randn(img_shape).to(device)
        net = net.to(device)
        for t in range(self.n_steps-1, -1, -1):
            x = self.sample_backward_step(x, t, net, simple_var)
        return x
    def training_sample_backward_with_t(self, img_shape, net, device, t_end, simple_var=True):
        x = torch.randn(img_shape).to(device)
        net = net.to(device)
        for t in range(self.n_steps-1, t_end-1, -1):
            x = self.sample_backward_step(x, t, net, simple_var)
        return x

In [68]:
class Classifier(nn.Module):
    def __init__(self, output_size):
        super(Classifier, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, output_size)  # Output has 10 classes
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [69]:
def get_dataloader(batch_size: int):
    transform = transforms.Compose([
        transforms.ToTensor()    ])
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    
    target_label = 0
    filtered_indices = [i for i, label in enumerate(train_dataset.targets) if label == target_label]
    dataset = torch.utils.data.Subset(train_dataset, filtered_indices)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [70]:
def get_img_shape():
    return 1, 28, 28

In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):

    def __init__(self, max_seq_len: int, d_model: int):
        super().__init__()

        # Assume d_model is an even number for convenience
        assert d_model % 2 == 0

        pe = torch.zeros(max_seq_len, d_model)
        i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
        j_seq = torch.linspace(0, d_model - 2, d_model // 2)
        pos, two_i = torch.meshgrid(i_seq, j_seq, indexing='ij')
        pe_2i = torch.sin(pos / 10000**(two_i / d_model))
        pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
        pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)

        self.embedding = nn.Embedding(max_seq_len, d_model)
        self.embedding.weight.data = pe
        self.embedding.requires_grad_(False)

    def forward(self, t):
        return self.embedding(t)


class ResidualBlock(nn.Module):

    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.actvation1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.actvation2 = nn.ReLU()
        if in_c != out_c:
            self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, 1),
                                          nn.BatchNorm2d(out_c))
        else:
            self.shortcut = nn.Identity()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.actvation1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x += self.shortcut(input)
        x = self.actvation2(x)
        return x


class ConvNet(nn.Module):

    def __init__(self,
                 n_steps,
                 intermediate_channels=[10, 20, 40],
                 pe_dim=10,
                 insert_t_to_all_layers=False):
        super().__init__()
        C, H, W = get_img_shape()  # 1, 28, 28
        self.pe = PositionalEncoding(n_steps, pe_dim)

        self.pe_linears = nn.ModuleList()
        self.all_t = insert_t_to_all_layers
        if not insert_t_to_all_layers:
            self.pe_linears.append(nn.Linear(pe_dim, C))

        self.residual_blocks = nn.ModuleList()
        prev_channel = C
        for channel in intermediate_channels:
            self.residual_blocks.append(ResidualBlock(prev_channel, channel))
            if insert_t_to_all_layers:
                self.pe_linears.append(nn.Linear(pe_dim, prev_channel))
            else:
                self.pe_linears.append(None)
            prev_channel = channel
        self.output_layer = nn.Conv2d(prev_channel, C, 3, 1, 1)

    def forward(self, x, t):
        n = t.shape[0]
        t = self.pe(t)
        for m_x, m_t in zip(self.residual_blocks, self.pe_linears):
            if m_t is not None:
                pe = m_t(t).reshape(n, -1, 1, 1)
                x = x + pe
            x = m_x(x)
        x = self.output_layer(x)
        return x


class UnetBlock(nn.Module):

    def __init__(self, shape, in_c, out_c, residual=False):
        super().__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.activation = nn.ReLU()
        self.residual = residual
        if residual:
            if in_c == out_c:
                self.residual_conv = nn.Identity()
            else:
                self.residual_conv = nn.Conv2d(in_c, out_c, 1)

    def forward(self, x):
        out = self.ln(x)
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        if self.residual:
            out += self.residual_conv(x)
        out = self.activation(out)
        return out


class UNet(nn.Module):

    def __init__(self,
                 n_steps,
                 channels=[10, 20, 40, 80],
                 pe_dim=10,
                 residual=False) -> None:
        super().__init__()
        C, H, W = get_img_shape()
        layers = len(channels)
        Hs = [H]
        Ws = [W]
        cH = H
        cW = W
        for _ in range(layers - 1):
            cH //= 2
            cW //= 2
            Hs.append(cH)
            Ws.append(cW)

        self.pe = PositionalEncoding(n_steps, pe_dim)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.pe_linears_en = nn.ModuleList()
        self.pe_linears_de = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        prev_channel = C
        for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):
            self.pe_linears_en.append(
                nn.Sequential(nn.Linear(pe_dim, prev_channel), nn.ReLU(),
                              nn.Linear(prev_channel, prev_channel)))
            self.encoders.append(
                nn.Sequential(
                    UnetBlock((prev_channel, cH, cW),
                              prev_channel,
                              channel,
                              residual=residual),
                    UnetBlock((channel, cH, cW),
                              channel,
                              channel,
                              residual=residual)))
            self.downs.append(nn.Conv2d(channel, channel, 2, 2))
            prev_channel = channel

        self.pe_mid = nn.Linear(pe_dim, prev_channel)
        channel = channels[-1]
        self.mid = nn.Sequential(
            UnetBlock((prev_channel, Hs[-1], Ws[-1]),
                      prev_channel,
                      channel,
                      residual=residual),
            UnetBlock((channel, Hs[-1], Ws[-1]),
                      channel,
                      channel,
                      residual=residual),
        )
        prev_channel = channel
        for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):
            self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))
            self.ups.append(nn.ConvTranspose2d(prev_channel, channel, 2, 2))
            self.decoders.append(
                nn.Sequential(
                    UnetBlock((channel * 2, cH, cW),
                              channel * 2,
                              channel,
                              residual=residual),
                    UnetBlock((channel, cH, cW),
                              channel,
                              channel,
                              residual=residual)))

            prev_channel = channel

        self.conv_out = nn.Conv2d(prev_channel, C, 3, 1, 1)

    def forward(self, x, t):
        n = t.shape[0]
        t = self.pe(t)
        encoder_outs = []
        for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
                                            self.downs):
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            x = encoder(x + pe)
            encoder_outs.append(x)
            x = down(x)
        pe = self.pe_mid(t).reshape(n, -1, 1, 1)
        x = self.mid(x + pe)
        
        mid_x = []
        for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
                                                       self.decoders, self.ups,
                                                       encoder_outs[::-1]):
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            x = up(x)
            
            pad_x = encoder_out.shape[2] - x.shape[2]
            pad_y = encoder_out.shape[3] - x.shape[3]
            x = F.pad(x, (pad_x // 2, pad_x - pad_x // 2, pad_y // 2,
                          pad_y - pad_y // 2))
            x = torch.cat((encoder_out, x), dim=1)
            x = decoder(x + pe)

            mid_x.append(x)
        x = self.conv_out(x)
        return x, mid_x






def build_network(config: dict, n_steps):
    network_type = config.pop('type')
    if network_type == 'ConvNet':
        network_cls = ConvNet
    elif network_type == 'UNet':
        network_cls = UNet

    network = network_cls(n_steps, **config)
    return network

In [72]:
class Discriminator(nn.Module):
    def __init__(self, channels, n_size):
        super(Discriminator, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(channels, channels*2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(channels*2, channels*4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels*4*n_size*n_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [73]:
from tqdm import tqdm

In [80]:
batch_size = 4
n_epochs = 1

def train(ddpm: DDPM, net, device):
    n_steps = ddpm.n_steps
    dataloader = get_dataloader(batch_size)
    net = net.to(device)
    loss_fn = nn.MSELoss()
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(net.parameters())
    discriminator_1 = Discriminator(16, 3).to(device) # (-1, 16, 3, 3)
    discriminator_1_optimizer = torch.optim.Adam(discriminator_1.parameters())
    
    discriminator_2 = Discriminator(8, 7).to(device)
    discriminator_2_optimizer = torch.optim.Adam(discriminator_2.parameters())
    
    discriminator_3 = Discriminator(4, 14).to(device)
    discriminator_3_optimizer = torch.optim.Adam(discriminator_3.parameters())
    
    discriminator_4 = Discriminator(2, 28).to(device)
    discriminator_4_optimizer = torch.optim.Adam(discriminator_4.parameters())
    for e in range(n_epochs):
        tot_loss = 0
        count = 0
        for x, _ in tqdm(dataloader):
            current_batch_size = x.shape[0]
            x = x.to(device)
            t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
            eps = torch.randn_like(x).to(device)
            x_t = ddpm.sample_forward(x, t, eps)
            eps_theta, x_mid = net(x_t, t.reshape(current_batch_size, 1))
            
            #loss = loss_fn(eps_theta, eps)
            #optimizer.zero_grad()
            #loss.backward(retain_graph=True)
            #optimizer.step()
            
            # Generate true data            
            real_labels = torch.ones(batch_size, 1).to(device)
            real_outputs = discriminator_1(x_mid[0])
            loss_real_1 = criterion(real_outputs, real_labels)
            
            real_outputs = discriminator_2(x_mid[1])
            loss_real_2 = criterion(real_outputs, real_labels)
            
            real_outputs = discriminator_3(x_mid[2])
            loss_real_3 = criterion(real_outputs, real_labels)
            
            real_outputs = discriminator_4(x_mid[3])
            loss_real_4 = criterion(real_outputs, real_labels)
            
            # Generate fake data
            net.eval()
            x_t_fake = torch.zeros_like(x_t)
            count = 0
            for end_time in t:
                fake_img = ddpm.training_sample_backward_with_t((1, *get_img_shape()), net, device=device, t_end=end_time, simple_var=True)
                x_t_fake[count] = fake_img[0]
                count += 1
            eps_theta, x_mid_fake = net(x_t_fake, t.reshape(current_batch_size, 1))
            
            
            fake_labels = torch.zeros(batch_size, 1).to(device)
            fake_outputs = discriminator_1(x_mid_fake[0])
            loss_fake_1 = criterion(fake_outputs, fake_labels)
            
            fake_outputs = discriminator_2(x_mid_fake[1])
            loss_fake_2 = criterion(fake_outputs, fake_labels)
            
            fake_outputs = discriminator_3(x_mid_fake[2])
            loss_fake_3 = criterion(fake_outputs, fake_labels)
            
            fake_outputs = discriminator_4(x_mid_fake[3])
            loss_fake_4 = criterion(fake_outputs, fake_labels)
            
            
            discriminator_1_optimizer.zero_grad()
            loss_discriminator_1 = loss_real_1 + loss_fake_1
            loss_discriminator_1.backward(retain_graph=True)
            discriminator_1_optimizer.step()
            
            discriminator_2_optimizer.zero_grad()
            loss_discriminator_2 = loss_real_2 + loss_fake_2
            loss_discriminator_2.backward(retain_graph=True)
            discriminator_2_optimizer.step()
            
            discriminator_3_optimizer.zero_grad()
            loss_discriminator_3 = loss_real_3 + loss_fake_3
            loss_discriminator_3.backward(retain_graph=True)
            discriminator_3_optimizer.step()
            
            discriminator_4_optimizer.zero_grad()
            loss_discriminator_4 = loss_real_4 + loss_fake_4
            loss_discriminator_4.backward(retain_graph=True)
            discriminator_4_optimizer.step()
            net.train()
            
            # Train Generator
            discriminator_1.eval()
            discriminator_2.eval()
            discriminator_3.eval()
            discriminator_4.eval()
            discriminator_outputs = discriminator_1(x_mid_fake[0]) # can generate again but waste a lot of time
            loss_generator_1 = criterion(discriminator_outputs, real_labels)
            
            discriminator_outputs = discriminator_2(x_mid_fake[1]) # can generate again but waste a lot of time
            loss_generator_2 = criterion(discriminator_outputs, real_labels)
            
            discriminator_outputs = discriminator_3(x_mid_fake[2]) # can generate again but waste a lot of time
            loss_generator_3 = criterion(discriminator_outputs, real_labels)
            
            discriminator_outputs = discriminator_4(x_mid_fake[3]) # can generate again but waste a lot of time
            loss_generator_4 = criterion(discriminator_outputs, real_labels)
            
            optimizer.zero_grad()
            loss_generator = loss_generator_1 + loss_generator_2 + loss_generator_3 + loss_generator_4
            loss_generator.backward()
            optimizer.step()
            discriminator_1.train()
            discriminator_2.train()
            discriminator_3.train()
            discriminator_4.train()


            
            #tot_loss += loss
            #count += 1
            print(loss_discriminator_1.item(), loss_discriminator_2.item(), loss_discriminator_3.item(), 
                  loss_discriminator_4.item(), loss_generator.item())
            
        
        print(tot_loss/count, e)

In [81]:
n_steps = 1000
device = 'cuda'
unet_res_cfg = {
    'type': 'UNet',
    'channels': [2, 4, 8, 16, 32], # [32, 64, 128, 256, 512],
    'pe_dim': 128,
    'residual': True
}

config = unet_res_cfg 
net = build_network(config, n_steps)
ddpm = DDPM(device, n_steps)

In [None]:
train(ddpm, net, device=device)

  0%|                                                                              | 1/1481 [01:05<26:53:42, 65.42s/it]

1.3880209922790527 1.3862806558609009 1.3839762210845947 1.5438129901885986 3.9326210021972656


  0%|                                                                              | 2/1481 [02:05<25:33:59, 62.23s/it]

1.38771390914917 1.3849060535430908 1.3777949810028076 1.0692633390426636 4.616359710693359


  0%|▏                                                                             | 3/1481 [02:43<21:04:26, 51.33s/it]

1.3874378204345703 1.3810412883758545 1.3702280521392822 1.0705336332321167 4.48138952255249


  0%|▏                                                                             | 4/1481 [03:25<19:26:40, 47.39s/it]

1.3863416910171509 1.3838279247283936 1.3645434379577637 1.0742318630218506 4.131069183349609


  0%|▎                                                                             | 5/1481 [04:19<20:31:17, 50.05s/it]

1.3855984210968018 1.383650541305542 1.336582899093628 0.9154834747314453 5.040769577026367


  0%|▎                                                                             | 6/1481 [05:28<23:07:14, 56.43s/it]

1.3863197565078735 1.3784836530685425 1.3219192028045654 0.8948585987091064 5.821846008300781


  0%|▎                                                                             | 7/1481 [06:04<20:21:02, 49.70s/it]

1.385983943939209 1.3866535425186157 1.3561079502105713 1.1022971868515015 3.581782102584839


  1%|▍                                                                             | 8/1481 [07:44<26:55:17, 65.80s/it]

1.3837788105010986 1.3852208852767944 1.2800487279891968 0.5738208293914795 6.370873928070068


  1%|▍                                                                             | 9/1481 [08:42<25:54:43, 63.37s/it]

1.3863023519515991 1.3843553066253662 1.3216886520385742 0.9232012033462524 3.8806238174438477


  1%|▌                                                                            | 10/1481 [09:56<27:14:44, 66.68s/it]

1.3858613967895508 1.3851443529129028 1.3117635250091553 0.9311338663101196 3.6919233798980713


  1%|▌                                                                            | 11/1481 [11:00<26:48:50, 65.67s/it]

1.3860989809036255 1.3849958181381226 1.333944320678711 0.9668624401092529 3.496884346008301


  1%|▌                                                                            | 12/1481 [12:14<27:50:17, 68.22s/it]

1.3857671022415161 1.385061502456665 1.3266479969024658 1.011543869972229 3.3706042766571045


  1%|▋                                                                            | 13/1481 [13:00<25:07:35, 61.62s/it]

1.3862736225128174 1.386423110961914 1.3588330745697021 1.239640712738037 3.0034537315368652


  1%|▋                                                                            | 14/1481 [14:04<25:24:02, 62.33s/it]

1.3859378099441528 1.3853254318237305 1.349257469177246 1.244666337966919 2.6891374588012695


  1%|▊                                                                            | 15/1481 [15:19<26:54:45, 66.09s/it]

1.386293888092041 1.385263204574585 1.3417869806289673 1.2442529201507568 2.82623028755188


In [35]:
torch.save(net.state_dict(), "model_unet_res_01.pth")

In [10]:
net.load_state_dict(torch.load("model_unet_res_01.pth"))

<All keys matched successfully>

In [18]:
inv_normalize = transforms.Normalize(
    mean=[-0.4914/0.2463, -0.4821/0.2428, -0.4465/0.2607],
    std=[1/0.2463, 1/0.2428, 1/0.2607]
)

In [12]:
def sample_imgs(ddpm,
                net,
                output_path,
                n_sample=81,
                device='cuda',
                simple_var=True):
    net = net.to(device)
    net = net.eval()
    with torch.no_grad():
        shape = (n_sample, *get_img_shape())  # 1, 3, 28, 28
        imgs = ddpm.sample_backward(shape,
                                    net,
                                    device=device,
                                    simple_var=simple_var).detach().cpu()
        imgs = inv_normalize(imgs)
        imgs = imgs.clamp(0, 255)
        imgs = einops.rearrange(imgs,
                                '(b1 b2) c h w -> (b1 h) (b2 w) c',
                                b1=int(n_sample**0.5))

        imgs = imgs.numpy().astype(np.uint8)

        cv2.imwrite(output_path, imgs)

In [99]:
sample_imgs(ddpm,
                net,
                "img.png",
                n_sample=81,
                device='cuda',
                simple_var=True)

In [11]:
shape = (1, *get_img_shape())  # 10, 1, 28, 28

In [12]:
imgs = ddpm.sample_backward(shape,
                                    net,
                                    device=device,
                                    simple_var=True).detach().cpu()
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
plt.imshow(imgs.view(28, 28), cmap='gray')
plt.show()

torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size

torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size

torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size

torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size

torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size

torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size

torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size

torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size

torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size

torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size

torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size([1, 32, 28, 28])
torch.Size([1, 256, 3, 3])
torch.Size([1, 128, 7, 7])
torch.Size([1, 64, 14, 14])
torch.Size

NameError: name 'os' is not defined