In [1]:
import numpy as np
import torch
import torchvision

from dataset import WeizmannHumanActionVideo
from mocogan import Generator, ImageDiscriminator, VideoDiscriminator, RNN, weights_init_normal

In [2]:
# Pick the compute device: the first CUDA GPU when one is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device: ", device)

device:  cuda:0


In [None]:
trans_data = torchvision.transforms.ToTensor()
trans_label = None
dataset = WeizmannHumanActionVideo(trans_data=None, trans_label=trans_label, train=True)

# train-test split
train_size = int(1.0 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
print("train: ", len(train_dataset))
print("test: ", len(test_dataset))

train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           num_workers=4)

In [None]:
# ===================
# Params
batch_size=1
n_epochs=100
T = 16
n_channel = 3
dim_zc = 2
dim_zm = 2
dim_e  = 16
# ===================

In [None]:
# Networks: latent RNN, generator, and the image / video discriminators,
# each moved onto the selected device right after construction.
netR  = RNN(dim_zm=dim_zm, dim_e=dim_e).to(device)
netG  = Generator(n_channel=n_channel, dim_zm=dim_zm, dim_zc=dim_zc).to(device)
netDI = ImageDiscriminator(n_channel=n_channel, dim_zm=dim_zm, dim_zc=dim_zc).to(device)
netDV = VideoDiscriminator(n_channel=n_channel, dim_zm=dim_zm, dim_zc=dim_zc).to(device)

# Weight initialisation, applied in the same order the networks were built.
for _net in (netR, netG, netDI, netDV):
    _net.apply(weights_init_normal)

# One Adam optimizer per network; all four share the same hyper-parameters.
_adam_kwargs = {"lr": 0.0002, "betas": (0.5, 0.999)}
optim_R  = torch.optim.Adam(netR.parameters(),  **_adam_kwargs)
optim_G  = torch.optim.Adam(netG.parameters(),  **_adam_kwargs)
optim_DI = torch.optim.Adam(netDI.parameters(), **_adam_kwargs)
optim_DV = torch.optim.Adam(netDV.parameters(), **_adam_kwargs)

# Binary cross-entropy averaged over all elements
# (alternative kept for reference: torch.nn.MSELoss()).
criterion = torch.nn.BCELoss(reduction='mean')

# Target values used for "real" and "fake" samples in the adversarial losses.
real_label = 1
fake_label = 0

In [25]:
def S_1(video):
    """
    Sample one random frame from every video in the batch.

    The same randomly drawn frame index is used for all videos of the batch.

    video: torch.Tensor()
        (batch_size, video_len, channel, height, width)
    image: torch.Tensor()
        (batch_size, channel, height, width)
    """
    idx = int(np.random.rand() * video.shape[1])
    # Index the frame axis directly. A blanket torch.squeeze() would also drop
    # the batch axis when batch_size == 1 and the channel axis when
    # channel == 1, giving a wrongly shaped result for other batch sizes.
    image = video[:, idx]
    return image

In [26]:
def S_T(video, T):
    """
    Sample a random clip of T consecutive frames from every video in the batch.

    The same randomly drawn start index is used for all videos of the batch.

    video: torch.Tensor()
        (batch_size, video_len, channel, height, width)
    T: int
        number of consecutive frames to keep; must satisfy 0 <= T <= video_len
    returns: torch.Tensor()
        (batch_size, T, channel, height, width)

    Raises ValueError when T is negative or larger than video_len.
    """
    video_len = video.shape[1]
    if not 0 <= T <= video_len:
        raise ValueError(
            "T must be between 0 and video_len ({}), got {}".format(video_len, T)
        )
    # Valid start indices are 0 .. video_len - T inclusive, hence the +1;
    # without it the last possible window would never be sampled.
    idx = int(np.random.rand() * (video_len - T + 1))
    return video[:, idx:idx+T]

In [27]:
# Random stand-in clip for sanity-checking the samplers:
# (batch_size=1, video_len=50, channel=3, height=96, width=96).
video = torch.randn(1, 50, 3, 96, 96)
video.shape

torch.Size([1, 50, 3, 96, 96])

In [28]:
# Sample a 5-frame clip from the dummy video.
y = S_T(video, T=5)

In [29]:
# Expected: (1, 5, 3, 96, 96) -- only the frame axis shrinks to T.
y.shape

torch.Size([1, 5, 3, 96, 96])

In [30]:
# Sample a single frame from the dummy video (reuses the name `y`).
y = S_1(video)

In [31]:
# Expected: (1, 3, 96, 96) -- the frame axis is gone.
y.shape

torch.Size([1, 3, 96, 96])

In [None]:
def _adversarial_loss(output, target_label):
    """BCE between a discriminator output and a constant target of the same shape.

    The target is built with torch.full_like so that its shape, floating-point
    dtype and device always match the discriminator output, as BCELoss requires.
    """
    return criterion(output, torch.full_like(output, target_label))


def train_model(epoch):
    """Train the RNN, generator and both discriminators for one epoch.

    For every batch of real videos from `train_loader`:
      1. DI / DV are updated to score real frames / clips as `real_label` and
         generated ones as `fake_label`.
      2. R / G are updated so that the generated frames / clips are scored as
         `real_label` by DI / DV.

    Uses the module-level networks, optimizers, `criterion`, `device` and the
    hyper-parameters `T`, `dim_zc`, `dim_e`.

    Parameters
    ----------
    epoch : int
        Index of the current epoch; only used in the progress message.

    Returns
    -------
    tuple of float
        (mean discriminator loss, mean generator loss) over the batches of
        this epoch; (0.0, 0.0) when the loader yields no batch.
    """
    # All four networks are optimised here, so all four go into training mode.
    for net in (netR, netG, netDI, netDV):
        net.train()
        net.to(device)

    train_loss_D = 0.0
    train_loss_G = 0.0
    n_batches = 0

    for batch_idx, (batch_data, _) in enumerate(train_loader):
        # batch_data: (batch_size, video_len, channel, height, width)
        batch_size, video_len = batch_data.shape[:2]

        # =====================================
        # (1) Update DI, DV network:
        #     maximize   log ( DI ( SI(x) ) ) + log(1 - DI ( SI ( G(z) ) ) )
        #              + log ( DV ( SV(x) ) ) + log(1 - DV ( SV ( G(z) ) ) )
        # =====================================

        ## ------------------------------------
        ## Train with all-real batch
        ## ------------------------------------
        netDI.zero_grad()
        netDV.zero_grad()

        # v_real: (batch_size, video_len, channel, height, width)
        v_real = batch_data.to(device)

        # Forward pass of one real frame / one real T-frame clip through DI / DV
        output_DI = netDI(S_1(v_real))
        output_DV = netDV(S_T(v_real, T))

        loss_D_real = (_adversarial_loss(output_DI, real_label)
                       + _adversarial_loss(output_DV, real_label))
        loss_D_real.backward()

        ## ------------------------------------
        ## Train with all-fake batch
        ## ------------------------------------
        # zc is drawn once per video and repeated for every frame;
        # e is drawn independently per frame and mapped to zm by the RNN.
        zc = torch.randn(batch_size, 1, dim_zc, device=device).repeat(1, video_len, 1)
        e  = torch.randn(batch_size, video_len, dim_e, device=device)
        # NOTE(review): assumes netR(e) returns zm directly -- confirm against mocogan.RNN.
        zm = netR(e)

        # v_fake: presumably (batch_size, video_len, channel, height, width) -- TODO confirm
        v_fake = netG(zc, zm)

        # Detach so that the discriminator update neither computes gradients
        # for G / R nor frees the generator graph needed again in step (2).
        v_fake_detached = v_fake.detach()
        output_DI = netDI(S_1(v_fake_detached))
        output_DV = netDV(S_T(v_fake_detached, T))

        loss_D_fake = (_adversarial_loss(output_DI, fake_label)
                       + _adversarial_loss(output_DV, fake_label))
        loss_D_fake.backward()

        loss_D = loss_D_real + loss_D_fake

        # Update DI, DV
        optim_DI.step()
        optim_DV.step()

        # =====================================
        # (2) Update G, R network:
        #     maximize  log(DI ( SI ( G(z) ) ) )
        #             + log(DV ( SV ( G(z) ) ) )
        # =====================================
        netR.zero_grad()
        netG.zero_grad()

        # Forward pass of the (non-detached) fake video through the freshly
        # updated discriminators, so gradients flow back into G and R.
        output_DI = netDI(S_1(v_fake))
        output_DV = netDV(S_T(v_fake, T))

        # The generator wants both discriminators to answer "real".
        loss_G = (_adversarial_loss(output_DI, real_label)
                  + _adversarial_loss(output_DV, real_label))
        loss_G.backward()

        # Update R, G. Gradients that this backward pass leaves on DI / DV are
        # cleared by the zero_grad() calls at the start of the next iteration.
        optim_R.step()
        optim_G.step()

        train_loss_D += loss_D.item()
        train_loss_G += loss_G.item()
        n_batches += 1

    mean_loss_D = train_loss_D / max(n_batches, 1)
    mean_loss_G = train_loss_G / max(n_batches, 1)
    print("epoch: {}  loss_D: {:.4f}  loss_G: {:.4f}".format(epoch, mean_loss_D, mean_loss_G))
    return mean_loss_D, mean_loss_G


if __name__ == "__main__":
    import os

    # torch.save does not create missing directories.
    os.makedirs('./trained_models', exist_ok=True)

    for epoch in range(n_epochs):
        train_model(epoch)

        # TODO: `eval_model` is not defined in the visible part of this
        # notebook; call it only when it is available instead of raising NameError.
        if callable(globals().get("eval_model")):
            eval_model(epoch)

        # One checkpoint per epoch holding all four networks. The tensors are
        # copied to the CPU for saving; the networks themselves stay on `device`.
        model_path = './trained_models/mocogan'+str(epoch)+'.pth'
        checkpoint = {
            name: {key: value.cpu() for key, value in net.state_dict().items()}
            for name, net in (("netR", netR), ("netG", netG),
                              ("netDI", netDI), ("netDV", netDV))
        }
        torch.save(checkpoint, model_path)
                                            
