# fMRI Transformer Conditional GAN

A conditional Transformer based time-series GAN that generates images from 1-D signal data over 360 channels, conditioned over 3 classes.

*Author's note: this is a reduced version of the full model intended for accessible use by others*

In [1]:
#@title Imports and Connect Drive (Optional)
# from google.colab import drive
# drive.mount('/content/drive/')

!pip install einops

import os, sys
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import dill
path = Path("SAMPLE DIRECTORY PATH")
os.chdir(path)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

from tqdm import tqdm

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms, datasets
from torch import optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torch import Tensor
from torchsummary import summary
from torchvision.transforms import Compose, Resize, ToTensor

Mounted at /content/drive/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


In [2]:
# @title Set device (GPU or CPU). Execute `set_device()`
# especially if torch modules used.

# Inform the user if the notebook uses GPU or CPU.

def set_device():
  """
  Set the device. CUDA if available, CPU otherwise

  Args:
    None

  Returns:
    Nothing
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  if device != "cuda":
    print("WARNING: For this notebook to perform best, "
        "if possible, in the menu under `Runtime` -> "
        "`Change runtime type.`  select `GPU` ")
  else:
    print("GPU is enabled in this notebook.")

  return device

In [3]:
# @title Set random seed

# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL its critical to set the random seed so that students can have a
# baseline to compare their results to expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch

def set_seed(seed=None, seed_torch=True):
  """
  Handles variability by controlling sources of randomness
  through set seed values

  Args:
    seed: Integer
      Set the seed value to given integer.
      If no seed, set seed value to random integer in the range 2^32
    seed_torch: Bool
      Seeds the random number generator for all devices to
      offer some guarantees on reproducibility

  Returns:
    Nothing
  """
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of subprocess to seed. 0 means that
      the data will be loaded in the main process
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)

In [4]:
# Set global variables
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()

Random seed 2021 has been set.
GPU is enabled in this notebook.


# DATASET NOTE:
Due to the architecture of the TTS-GAN, we do not reduce the data to its mean, and instead opt for batch sizes of 33 to match the sequence of events that occur.

In [18]:
#@title Create Dataset

class FMRIDataset(Dataset):
    """HCP Gambling FMRI dataset."""

    def __init__(self, data_file="data_fmri_all_w_init.csv", 
                 labels_file="data_events_all_w_init.csv", 
                 root_dir="/content/drive/MyDrive/nmaproject/", 
                 transform=None,onehot=True):
        """
        Args:
            data_file (string): Path to the csv file with each subject's fMRI readings
            label_file(string): Path to the csv file with the outcome of each trial
            root_dir  (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        def f(x):
          if x == "init" or x == "neut": return 0
          elif x == "loss": return 1
          else: return 2

        self.root_dir = root_dir
        self.transform = transform
        
        data = pd.read_csv(self.root_dir + data_file).drop(columns="Unnamed: 0")
        data = torch.tensor(data.values)
        # OG: self.data = data.reshape(100,360,33,5).mean(dim=3).unsqueeze(2).type(torch.FloatTensor)
        self.data = data.reshape(100,360,33,5).reshape((33*100,360,1,5)).type(torch.FloatTensor)

        labels = pd.read_csv(self.root_dir + labels_file).drop(columns="Unnamed: 0")
        labels_mapped = labels.applymap(f)
        labels = torch.tensor(labels_mapped.values)
        # OG: self.labels = labels.reshape(100,33)
        self.labels = labels.reshape((33 * 100,))
        if onehot:
          self.labels = F.one_hot(self.labels) # OG: (100,33,3), NEW: (3300,3)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.data[idx] # OG: (360,1,33), NEW: (360,1,5)
        y = self.labels[idx] # OG: (33,3), NEW: (5) OR (5,3) if one-hot

        if self.transform:
            x = self.transform(x)
        
        sample = (x,y)

        return sample

def custom_normalize(x):
    """ A custom normalization method
        Returns
            result: a normalized epoch
    """
    e = 1e-10
    result = (x - x.mean(axis=0)) / ((torch.sqrt(x.var(axis=0)))+e)
    return result

### Defining Parameters and Models

In [19]:
# training_parameters
cond = True # conditional GAN vs regular GAN
shrink = 10
pars = {
    "n_epochs": 11,#1000,
    "batch_size": 33,
    "g_lr": 0.0001,
    "d_lr": 0.0003, # discriminator learning rate
    "wd": 1e-3,     # weight decay
    "beta1": 0.9,  # for adam
    "beta2": 0.999, # for adam
    "n_critic": 1,
    "latent":(3*360*33 // 2) // shrink, # due to RAM constraints
    "embed_dim":(3*36*33 // shrink) + (3*36*33 % 5),
    "dis_embed_dim":(3*36*33 // shrink) + (3*36*33 % 5),#(330) // shrink,
    "ema_kimg":500,
    "ema_warmup":0.1,
    "ema":0.9999,
    "global_steps":0,
    "patch_size":1, #3
    "seq_len":5,
}

dataset = FMRIDataset(transform=custom_normalize,onehot=False) # onehot should be False for CGAN
data_loader = DataLoader(dataset, batch_size=pars["batch_size"], shuffle=True)

In [11]:
#@title Define Models
class Generator(nn.Module):
    def __init__(self, seq_len=pars["seq_len"], channels=360, num_classes=3, latent_dim=pars["latent"], embed_dim=pars["embed_dim"], depth=3,
                 num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):
        super(Generator, self).__init__()
        self.channels = channels
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.depth = depth
        self.attn_drop_rate = attn_drop_rate
        self.forward_drop_rate = forward_drop_rate
        
        self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))
        self.blocks = Gen_TransformerEncoder(
                         depth=self.depth,
                         emb_size = self.embed_dim,
                         drop_p = self.attn_drop_rate,
                         forward_drop_p=self.forward_drop_rate
                        )

        self.deconv = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0)
        )
        
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, z):
        x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
        x = x + self.pos_embed
        H, W = 1, self.seq_len
        x = self.blocks(x)
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        output = self.deconv(x.permute(0, 3, 1, 2))
        output = output.view(-1, self.channels, H, W)
        return output

# Conditional GAN
class CGenerator(nn.Module):
    def __init__(self, seq_len=pars["seq_len"], channels=360, num_classes=3, latent_dim=pars["latent"], data_embed_dim=pars["embed_dim"], 
                label_embed_dim=pars["embed_dim"] ,depth=3, num_heads=5, 
                forward_drop_rate=0.5, attn_drop_rate=0.5):

        super(CGenerator, self).__init__()
        self.seq_len = seq_len
        self.channels = channels
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.data_embed_dim = data_embed_dim
        self.label_embed_dim = label_embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.attn_drop_rate = attn_drop_rate
        self.forward_drop_rate = forward_drop_rate
        
        self.l1 = nn.Linear(self.latent_dim + self.label_embed_dim, self.seq_len * self.data_embed_dim)
        self.label_embedding = nn.Embedding(self.num_classes, self.label_embed_dim) 
        
        self.blocks = Gen_TransformerEncoder(
                 depth=self.depth,
                 emb_size = self.data_embed_dim,
                 num_heads = self.num_heads,
                 drop_p = self.attn_drop_rate,
                 forward_drop_p=self.forward_drop_rate
                )

        self.deconv = nn.Sequential(
            nn.Conv2d(self.data_embed_dim, self.channels, 1, 1, 0)
        )
        
    def forward(self, z, labels):
        c = self.label_embedding(labels)
        x = torch.cat([z, c], 1)
        x = self.l1(x)
        x = x.view(-1, self.seq_len, self.data_embed_dim)
        H, W = 1, self.seq_len
        x = self.blocks(x)
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        output = self.deconv(x.permute(0, 3, 1, 2))
        return output

    @property
    def device(self):
        return next(self.parameters()).device
    
class Gen_TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=5,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

        
class Gen_TransformerEncoder(nn.Sequential):
    def __init__(self, depth=8, **kwargs):
        super().__init__(*[Gen_TransformerEncoderBlock(**kwargs) for _ in range(depth)])       
        
        
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

    
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
    
    
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

        
        
class Dis_TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size=100,
                 num_heads=5,
                 drop_p=0.,
                 forward_expansion=4,
                 forward_drop_p=0.):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class Dis_TransformerEncoder(nn.Sequential):
    def __init__(self, depth=8, **kwargs):
        super().__init__(*[Dis_TransformerEncoderBlock(**kwargs) for _ in range(depth)])
        
        
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size=100, n_classes=2):
        super().__init__()
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )

    def forward(self, x):
        out = self.clshead(x)
        return out

class C_ClassificationHead(nn.Sequential):
    def __init__(self, emb_size=100, adv_classes=2, cls_classes=10):
        super().__init__()
        self.adv_head = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, adv_classes)
        )
        self.cls_head = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, cls_classes)
        )

    def forward(self, x):
        out_adv = self.adv_head(x)
        out_cls = self.cls_head(x)
        return out_adv, out_cls
    
class PatchEmbedding_Linear(nn.Module):
    #what are the proper parameters set here?
    def __init__(self, in_channels = 21, patch_size = 16, emb_size = 100, seq_length = 1024):
        # self.patch_size = patch_size
        super().__init__()
        #change the conv2d parameters here
        self.projection = nn.Sequential(
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)',s1 = 1, s2 = patch_size),
            nn.Linear(patch_size*in_channels, emb_size)
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((seq_length // patch_size) + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        #prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # position
        x += self.positions
        return x        
        
class Discriminator(nn.Sequential): # for CGAN parent from nn.Module
    def __init__(self, 
                 in_channels=360,
                 patch_size=pars["patch_size"],
                 emb_size=pars["dis_embed_dim"], 
                 seq_length = pars["seq_len"],
                 depth=3, 
                 n_classes=1, #before 3 
                 **kwargs):

        super().__init__(
              PatchEmbedding_Linear(in_channels, patch_size, emb_size, seq_length),
              Dis_TransformerEncoder(depth, emb_size=emb_size, drop_p=0.5, forward_drop_p=0.5, **kwargs),
              ClassificationHead(emb_size, n_classes),
          )
        
        @property
        def device(self):
            return next(self.parameters()).device

class CDiscriminator(nn.Sequential):
    def __init__(self, 
                 in_channels=360,
                 patch_size=pars["patch_size"],
                 data_emb_size=pars["dis_embed_dim"],
                 label_emb_size=pars["dis_embed_dim"],
                 seq_length = pars["seq_len"],
                 depth=3, 
                 n_classes=3, 
                 **kwargs):
        super().__init__(
            PatchEmbedding_Linear(in_channels, patch_size, data_emb_size, seq_length),
            Dis_TransformerEncoder(depth, emb_size=data_emb_size, drop_p=0.5, forward_drop_p=0.5, **kwargs),
            C_ClassificationHead(data_emb_size, 1, n_classes)
        )
    @property
    def device(self):
        return next(self.parameters()).device


### Training

### Training Loop

In [33]:
load_model = False

In [34]:
#@title Load Model
if load_model:
  if cond:
    model_name = "tts-cgan"
  else:
    model_name = "tts-gan"
  checkpoint = torch.load(path / f'{model_name}_checkpoint.pth',map_location=DEVICE,pickle_module=dill)
else:
  checkpoint = {}

In [35]:
#@title Helper Functions

class LinearLrDecay(object):
    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):

        assert start_lr > end_lr
        self.optimizer = optimizer
        self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_end_step
        self.start_lr = start_lr
        self.end_lr = end_lr

    def step(self, current_step):
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return lr
        
def copy_params(model, mode='cpu'):
    if mode == 'gpu':
        flatten = []
        for p in model.parameters():
            cpu_p = deepcopy(p).cpu()
            flatten.append(cpu_p.data)
    else:
        flatten = deepcopy(list(p.data for p in model.parameters()))
    return flatten

def gradient_penalty(y, x):
    """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
    weight = torch.ones(y.size()).to(DEVICE)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]

    # dydx = dydx.view(dydx.size(0), -1)
    dydx = dydx.reshape(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
    return torch.mean((dydx_l2norm-1)**2)    

In [36]:
#@title Training Loop
def train(data_loader, checkpoint = {}, cond=False):
    if checkpoint:
      generator = checkpoint["gen_net"]
      discriminator = checkpoint["dis_net"]
      gen_optimizer = checkpoint["gen_optimizer"]
      dis_optimizer = checkpoint["dis_optimizer"]
      gen_scheduler = checkpoint["gen_scheduler"]
      dis_scheduler = checkpoint["dis_scheduler"]
      start_epoch = int(checkpoint["epoch"])
      global_steps = start_epoch * len(data_loader)
    else:
      if cond:
        generator = CGenerator().to(DEVICE)
        discriminator = CDiscriminator().to(DEVICE)
      else:
        generator = Generator().to(DEVICE)
        discriminator = Discriminator().to(DEVICE)
      gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
                                          pars["g_lr"], (pars["beta1"], pars["beta2"]))
      dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
                                      pars["d_lr"], (pars["beta1"], pars["beta2"]))
      gen_scheduler = LinearLrDecay(gen_optimizer, pars["g_lr"], 0.0, 0, pars["n_epochs"] * pars["n_critic"])
      dis_scheduler = LinearLrDecay(dis_optimizer, pars["d_lr"], 0.0, 0, pars["n_epochs"] * pars["n_critic"])
      start_epoch = 0
      global_steps = 0

    if cond:
      model_name="tts-cgan"
    else:
      model_name="tts-gan"

    gen_avg_param = copy_params(generator)
    n_epochs = pars["n_epochs"]

    ## FOR CGAN
    # adv_criterion = nn.BCELoss()
    lambda_cls = 1 # same as paper
    lambda_gp = 10
    cls_criterion = nn.CrossEntropyLoss()

    # Set to Train Mode
    generator.train()
    discriminator.train()
    
    for epoch_idx in range(start_epoch,n_epochs):
        for iter_idx, (imgs, labels) in enumerate(tqdm(data_loader)):
          # Adversarial ground truths
            real_imgs = imgs.type(torch.FloatTensor).to(DEVICE)
            real_img_labels = labels.type(torch.LongTensor).to(DEVICE)
            noise = torch.randn(pars["batch_size"],pars["latent"]).to(DEVICE) # same as 'z' in TTS-GAN
            fake_labels = torch.randint(0, 3, (pars["batch_size"],)).to(DEVICE) # the 3 represents the number of classes

            #Train Discriminator
            discriminator.zero_grad()

            if cond:
              r_out_adv, r_out_cls = discriminator(real_imgs)
              fake_imgs = generator(noise, fake_labels)
            else:
              real_validity = discriminator(real_imgs)
              fake_imgs = generator(noise).detach()

            assert fake_imgs.size() == real_imgs.size(), f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}"


            if cond:
              f_out_adv, f_out_cls = discriminator(fake_imgs)
              # Compute loss for gradient penalty.
              alpha = torch.rand(real_imgs.size(0), 1, 1, 1).to(DEVICE)  # bh, C, H, W
              x_hat = (alpha * real_imgs.data + (1 - alpha) * fake_imgs.data).requires_grad_(True)
              out_src, _ = discriminator(x_hat)
              d_loss_gp = gradient_penalty(out_src, x_hat)
              
              d_real_loss = -torch.mean(r_out_adv)
              d_fake_loss = torch.mean(f_out_adv)
              d_adv_loss = d_real_loss + d_fake_loss 
              
              d_cls_loss = cls_criterion(r_out_cls, real_img_labels)
              
              d_loss = d_adv_loss + lambda_cls * d_cls_loss + lambda_gp * d_loss_gp
              d_loss.backward()
              
              torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 5.)
              dis_optimizer.step()
            else:
              fake_validity = discriminator(fake_imgs)

              real_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 1., dtype=torch.float, device=DEVICE)
              fake_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 0., dtype=torch.float, device=DEVICE)
              d_real_loss = nn.MSELoss()(real_validity, real_label)
              d_fake_loss = nn.MSELoss()(fake_validity, fake_label)
              d_loss = d_real_loss + d_fake_loss

              # skip accumulated_times since default is 1
              d_loss.backward()

              torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 5.)
              dis_optimizer.step()
              dis_optimizer.zero_grad()

            # Train Generator
            generator.zero_grad()

            if cond:
              gen_imgs = generator(noise, fake_labels)
              g_out_adv, g_out_cls = discriminator(gen_imgs)

              g_adv_loss = -torch.mean(g_out_adv)
              g_cls_loss = cls_criterion(g_out_cls, fake_labels)    
              g_loss = g_adv_loss + lambda_cls * g_cls_loss
              g_loss.backward()

              torch.nn.utils.clip_grad_norm_(generator.parameters(), 5.)
              gen_optimizer.step()
            else:
              # since n_critic also defaults to 1, train every iteration
              gen_z = torch.randn(pars["batch_size"],pars["latent"]).to(DEVICE)
              gen_imgs = generator(gen_z)
              fake_validity = discriminator(gen_imgs)

              # calculate loss
              loss_lz = torch.tensor(0)
              real_label = torch.full((fake_validity.shape[0],fake_validity.shape[1]), 1., dtype=torch.float, device=DEVICE)
              g_loss = nn.MSELoss()(fake_validity, real_label)
              # skip accumulated_times since default is 1
              g_loss.backward()

              torch.nn.utils.clip_grad_norm_(generator.parameters(), 5.)
              gen_optimizer.step()
              gen_optimizer.zero_grad()

            # schedulers
            g_lr = gen_scheduler.step(global_steps)
            d_lr = dis_scheduler.step(global_steps)

            # moving average weight
            ema_nimg = pars["ema_kimg"] * 1000
            cur_nimg = pars["batch_size"] * global_steps

            if pars["ema_warmup"] != 0:
                ema_nimg = min(ema_nimg, cur_nimg * pars["ema_warmup"])
                ema_beta = 0.5 ** (float(pars["batch_size"]) / max(ema_nimg, 1e-8))
            else:
                ema_beta = pars["ema"]
            
            # moving average weight
            for p, avg_p in zip(generator.parameters(), gen_avg_param):
                cpu_p = deepcopy(p)
                avg_p.mul_(ema_beta).add_((1. - ema_beta + cpu_p.cpu().data).to(DEVICE))
                del cpu_p

        tqdm.write(
                f"[Epoch {epoch_idx + 1}/{n_epochs}] [Batch {(iter_idx % len(data_loader) + 1)}/{len(data_loader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}] [ema: {ema_beta:.4f}] ")
        torch.save(generator, path / f'./{model_name}_generator.pt')
        torch.save(discriminator, path / f'./{model_name}_discriminator.pt')
        if (epoch_idx and epoch_idx % 10 == 0) or epoch_idx == pars["n_epochs"] - 1:
          checkpoint["epoch"] = epoch_idx
          checkpoint['gen_net'] = generator
          checkpoint['dis_net'] = discriminator
          checkpoint['gen_scheduler'] = gen_scheduler
          checkpoint['dis_scheduler'] = dis_scheduler
          checkpoint['gen_optimizer'] = gen_optimizer
          checkpoint['dis_optimizer'] = dis_optimizer
          torch.save(checkpoint, path / f'{model_name}_checkpoint.pth',pickle_module=dill)
          print("Checkpoint Saved")

In [37]:
train(data_loader,checkpoint=checkpoint,cond=cond)

100%|██████████| 100/100 [00:12<00:00,  7.86it/s]


[Epoch 2/11] [Batch 100/100] [D loss: -13.4343] [G loss: 4.7232] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.80it/s]


[Epoch 3/11] [Batch 100/100] [D loss: 0.8747] [G loss: -2.3898] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.93it/s]


[Epoch 4/11] [Batch 100/100] [D loss: 1.2891] [G loss: 0.1110] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.92it/s]


[Epoch 5/11] [Batch 100/100] [D loss: 1.0480] [G loss: 1.2922] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.90it/s]


[Epoch 6/11] [Batch 100/100] [D loss: 1.5584] [G loss: 0.5161] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.88it/s]


[Epoch 7/11] [Batch 100/100] [D loss: 1.6572] [G loss: 0.4290] [ema: 0.0000] 


100%|██████████| 100/100 [00:13<00:00,  7.33it/s]


[Epoch 8/11] [Batch 100/100] [D loss: 2.2454] [G loss: -0.7184] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.96it/s]


[Epoch 9/11] [Batch 100/100] [D loss: 2.5859] [G loss: -1.2685] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.86it/s]


[Epoch 10/11] [Batch 100/100] [D loss: 2.2155] [G loss: -1.4504] [ema: 0.0000] 


100%|██████████| 100/100 [00:12<00:00,  7.86it/s]


[Epoch 11/11] [Batch 100/100] [D loss: 1.9309] [G loss: -1.6242] [ema: 0.0000] 
Checkpoint Saved
