# Train the sequence model on raw data

### Embed arts and shadings

1. Train the embedding model on the custom dataset.

In [None]:
# sanity check: what the transformed image looks like
from dataset.rawdata import ImageDataset
import matplotlib.pyplot as plt
from torchvision import transforms
# Preprocessing for the sanity check: convert to a tensor first, then resize
# to 100x100.
preprocess_steps = [transforms.ToTensor(), transforms.Resize((100, 100))]
transform = transforms.Compose(preprocess_steps)
data = ImageDataset(transform, '../data/asset/art', '../data/asset/shading')
# Move the first axis of the sample to the end before plotting
# (presumably channels-first -> channels-last for imshow; confirm against
# ImageDataset).
first_sample = data[0]
plt.imshow(first_sample.permute((1, 2, 0)))
plt.show()

In [None]:
# compare with what the original picture looks like
from skimage import io
# Read the untransformed file behind sample 0 straight from disk.
original_path = data.img_paths[0]
img0 = io.imread(original_path)
# keep only the first three channels (drops the alpha channel)
rgb_only = img0[:, :, :3]
plt.imshow(rgb_only)
plt.show()

In [None]:
# Run the project's embedding training script as a shell command, on the
# "raw" data set with --lr 0.001 and --epochs 200.
!python embedding/embedding_main.py --data-set raw --lr 0.001 --epochs 200

2. load encoder

In [None]:
!pip install torch==1.9.0
!pip install torchvision==0.10.0

In [None]:
import torch
from embedding.models.embed_model import ConvEncoderDecoder

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
convcoder = ConvEncoderDecoder().to(device)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
convcoder.load_state_dict(torch.load('output/convcoder_raw.pt', map_location=device))
# The cells shown here only use the model for inference: switch to eval mode
# so that dropout / batch-norm layers (if the model has any) behave
# deterministically when the embeddings are computed.
convcoder.eval()

3. get embeddings of foreground images and background images

In [None]:
import glob
from skimage import io
from torchvision import transforms 

# Preprocessing applied to every image before it is encoded: tensor
# conversion first, then a resize to 100x100.
resize_to_model_input = transforms.Resize((100, 100))
transform = transforms.Compose([transforms.ToTensor(), resize_to_model_input])

def encode(image, convcoder):
    '''
    Embed a single image with the trained encoder/decoder model.

    The image goes through the module-level ``transform``, gets a leading
    batch axis of size 1, is moved to the module-level ``device`` and is then
    passed through ``convcoder.encoder`` and ``convcoder.embedder`` with
    gradient tracking disabled.

    Parameters:
    image: any input accepted by ``transform``
    convcoder: model exposing ``encoder`` and ``embedder`` sub-modules

    Returns:
    the embedder output, moved to the CPU (callers index ``[0]`` to drop the
    batch axis)
    '''
    batch = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        features = convcoder.encoder(batch)
        # collapse every non-batch axis before the embedding layer
        flat_features = torch.flatten(features, start_dim=1, end_dim=-1)
        embedding = convcoder.embedder(flat_features).cpu()
    return embedding

def get_embeddings(image_dir, encoder):
    '''
    Embed every ``*.png`` file found directly inside ``image_dir``.

    Parameters:
    image_dir: str
        directory whose png files are named ``<integer>.png``
    encoder: nn.Module
        model handed to ``encode``

    Returns:
    a dict of {image_number:embedding} pair

    Raises:
    FileNotFoundError: no png file was found in ``image_dir``
    ValueError: a file name does not start with an integer
    '''
    import os  # local import: `os` is not imported yet at this point of the notebook

    image_paths = glob.glob(image_dir + '/*.png')
    # explicit raise instead of `assert`, which is stripped under `python -O`
    if not image_paths:
        raise FileNotFoundError('No png image found')

    encoded_images = {}
    for path in image_paths:
        # glob returns OS-native separators, so take the file name with
        # os.path.basename rather than splitting on '/'
        image_number = int(os.path.basename(path).split('.')[0])
        # skimage reads the file as an array; only the first three channels
        # are encoded (drops the alpha channel)
        image = io.imread(path)
        # encode() returns a batch of one; keep the single embedding
        encoded_images[image_number] = encode(image[:, :, :3], encoder)[0]
    return encoded_images

In [None]:
# Embed every foreground art and every background shading image with the
# loaded model; both dicts are keyed by the integer file name.
fg_embs = get_embeddings('../data/asset/art', convcoder)
bg_embs = get_embeddings('../data/asset/shading', convcoder)

## Load data for training

A backup: the first version of `build_raw_data` (no mini-batching).
```python
def build_raw_data(bg_embs, fg_embs, img_manifest_path):
    with open(img_manifest_path) as f:
        img_list = [json.loads(line.strip()) for line in f.readlines()]
    training_data = []
    for img_info in img_list:
        # record last x and y for calculating relative position
        x_last = int(img_info['fg'][0][2])
        y_last = int(img_info['fg'][0][3])
        fg_reps = []
        for fg in img_info['fg']:
            number, _, x, y, scale, rotate, opaque = fg
            # convert rotate in [0, 360] to [0, 1]
            rotate /= 360
            fg_emb = fg_embs[number]
            x_rel, y_rel = x - x_last, y - y_last
            x_last, y_last = x, y
            # normalize x_rel and y_rel to [-1, 1]
            fg_meta = torch.tensor([*normalize_relative_xy(x_rel, y_rel), scale, rotate, opaque])
            fg_reps.append(torch.cat((fg_emb, fg_meta)))
        fg_reps = torch.stack(fg_reps)
        fg_reps = fg_reps.unsqueeze(dim=1) # (steps, mb_size=1, fg_emb_dim+5), 1 is the batch_size
        # normalize class label from [-1, 1] to [0, 1] (TODO: -1 and 0 are both 0)
        cls_label = 0 if int(img_info['flag']) < 0 else 1
        cls_label = torch.tensor(cls_label, dtype=torch.long).unsqueeze(dim=0) # (mb_size=1, fg_emb_dim)
        bk_emb = bg_embs[img_info['bg']].unsqueeze(dim=0)
        training_data.append((fg_reps, bk_emb, cls_label))
    return training_data

```

In [None]:
import json
from collections import defaultdict

def normalize_relative_xy(x, y, x_range=402, y_range=600):
    '''
    Scale a relative offset (x, y) into [-1, 1].

    Parameters:
    x, y: number
        offsets, expected in [-x_range, x_range] and [-y_range, y_range]
    x_range, y_range: number
        largest absolute offset along each axis; the defaults 402 and 600
        are magic numbers that come from inspection of raw data

    Returns:
    the tuple (x / x_range, y / y_range)
    '''
    return x / x_range, y / y_range

def _fg_sequence_tensor(fg_items, fg_embs):
    '''
    Encode one image's foreground list as a tensor of shape (steps, 1, emb_size + 5).
    '''
    # positions are encoded relative to the previous art; the first art is
    # measured against (the integer part of) its own position
    x_last = int(fg_items[0][2])
    y_last = int(fg_items[0][3])
    steps = []
    for number, _, x, y, scale, rotate, opaque in fg_items:
        x_rel, y_rel = x - x_last, y - y_last
        x_last, y_last = x, y
        # x_rel / y_rel are normalized to [-1, 1]; rotate goes from
        # [0, 360] to [0, 1]
        fg_meta = torch.tensor([*normalize_relative_xy(x_rel, y_rel), scale, rotate / 360, opaque])
        steps.append(torch.cat((fg_embs[number], fg_meta)))
    # insert the mini-batch axis (size 1) after the step axis
    return torch.stack(steps).unsqueeze(dim=1)


def build_raw_data(bg_embs, fg_embs, img_manifest_path, batch_size):
    '''
    Generate training data with a meta file and corresponding embeddings

    Each non-blank line of the manifest is one JSON object with the keys
    'fg' (list of [number, <unused>, x, y, scale, rotate, opaque] entries),
    'bg' (key into bg_embs) and 'flag' (class flag). Images are grouped by
    the length of their 'fg' list so every mini-batch holds sequences of one
    length.

    Parameters:
    bg_embs: dict
        background embedding dictionary. key: number; entry: tensor of shape (emb_size, )
    fg_embs: dict
        foreground embedding dictionary. key: number; entry: tensor of shape (emb_size, )
    img_manifest_path: str
    batch_size: int
        maximum number of images per mini-batch, must be >= 1

    Returns:
    a list of (fg_reps_batch, bk_emb_batch, cls_label_batch) tuples with
    shapes (steps, mb_size, emb_size + 5), (mb_size, emb_size) and (mb_size, )

    Raises:
    ValueError: batch_size is smaller than 1
    '''
    if batch_size < 1:
        # a negative step would make range() below silently yield no batches
        raise ValueError('batch_size must be >= 1')

    with open(img_manifest_path, encoding='utf-8') as f:
        # skip blank lines, which json.loads would reject
        img_list = [json.loads(line.strip()) for line in f if line.strip()]

    # group image by foreground art sequence length
    img_group = defaultdict(list)
    for img_info in img_list:
        img_group[len(img_info['fg'])].append(img_info)

    training_data = []
    for imgs in img_group.values():
        # create mini-batches for data of a certain foreground sequence length
        for start in range(0, len(imgs), batch_size):
            chunk = imgs[start:start + batch_size]
            # join the per-image (steps, 1, dim) tensors along the batch axis
            fg_reps_batch = torch.cat(
                [_fg_sequence_tensor(info['fg'], fg_embs) for info in chunk], dim=1)
            bk_emb_batch = torch.stack([bg_embs[info['bg']] for info in chunk])
            # class label: a negative flag maps to 0, everything else
            # (0 included) maps to 1
            cls_label_batch = torch.tensor(
                [0 if int(info['flag']) < 0 else 1 for info in chunk], dtype=torch.long)
            training_data.append((fg_reps_batch, bk_emb_batch, cls_label_batch))
    return training_data

In [None]:
# Turn the manifest into mini-batches of at most 128 equal-length sequences.
train_data = build_raw_data(
    bg_embs=bg_embs,
    fg_embs=fg_embs,
    img_manifest_path='../../02_data/data.txt',
    batch_size=128,
)

## Train the model

0. define hyper parameters

In [None]:
# Hyper-parameters shared by the model, optimizer and training-loop cells.
args = {
    # per-step input width: the generation cell splits it into 576
    # embedding values followed by 5 placement values
    'seq_in_dim': 581,
    # presumably the width of the background embedding that seeds the
    # hidden state -- confirm against sequence_model
    'input_hid_size': 576,
    'hid_dim': 256,
    'num_layers': 3,
    'lr': 0.001,  # Adam learning rate
    'wd': 1e-6,  # Adam weight decay
    'epochs': 700,
}

1. examine training data

In [None]:
# Unpack the first mini-batch and report the shape of each of its three parts.
first_in_seqs, first_bk_embs, first_cls_labels = train_data[0]
print("Input sequence has shape {}".format(first_in_seqs.shape))
print("Background embedding has shape {}".format(first_bk_embs.shape))
print("Class labels has shape {}".format(first_cls_labels.shape))

2. define model(s)

In [None]:
from sequence.models.seq_model import sequence_model, seq_loss_fn

# Build the sequence model from the hyper-parameters in `args` and place it
# on `device`.
model_kwargs = {
    'input_size': args['seq_in_dim'],
    'input_hid_size': args['input_hid_size'],
    'hidden_size': args['hid_dim'],
    'num_layers': args['num_layers'],
}
seq_model = sequence_model(**model_kwargs).to(device)

3. define loss and optimizer
   
   The customized loss function `seq_loss_fn` is used directly, so only the optimizer is defined here.

In [None]:
# Adam over all parameters of the sequence model; learning rate and weight
# decay come from `args`.
optim = torch.optim.Adam(seq_model.parameters(), lr=args['lr'], weight_decay=args['wd'])

Define the TensorBoard writer.

In [None]:
from datetime import datetime
import os
from torch.utils.tensorboard import SummaryWriter

# Log each run to its own directory under ./runs, named after the start time
# as month_day_hour_minute. The name carries no year or seconds, so two runs
# started within the same minute share a directory.
now = datetime.today()
dt= now.strftime("%m_%d_%H_%M")
writer = SummaryWriter(os.path.join('./runs', dt))
# store the hyper-parameters alongside the logged curves
writer.add_text('Parameters', str(args))

4. training loop

In [None]:
# One pass over all mini-batches per epoch. The loss of each batch is logged
# per sample; the epoch loss is logged per sample as well, so the
# 'batch/train_loss' and 'epoch/train_loss' curves share one scale.
for epoch in range(args['epochs']):
    seq_model.train()
    total_loss = 0
    total_samples = 0
    for i, (in_seqs, bk_embs, cls_labels) in enumerate(train_data):
        in_seqs = in_seqs.to(device)
        bk_embs = bk_embs.to(device)
        cls_labels = cls_labels.to(device)

        # every layer starts from the background embedding; the cell state
        # starts at zero. Both inherit the device of bk_embs.
        h_0 = torch.stack([bk_embs] * args['num_layers'])
        c_0 = torch.zeros_like(h_0)

        # feed steps [0, T-1) and compare the outputs with steps [1, T)
        out_seqs_logits, cls_logits = seq_model(in_seqs[:-1], (h_0, c_0), return_last_hidden=False)
        loss = seq_loss_fn(out_seqs_logits, in_seqs[1:], cls_logits, cls_labels, alpha=0.2)

        batch_size = len(cls_labels)
        batch_loss = loss.item()
        total_loss += batch_loss
        total_samples += batch_size
        writer.add_scalar('batch/train_loss', batch_loss/batch_size, global_step=epoch*len(train_data)+i)

        optim.zero_grad()
        loss.backward()
        optim.step()

    # normalize by the number of samples seen; guard against an empty data set
    epoch_loss = total_loss / total_samples if total_samples else 0.0
    writer.add_scalar('epoch/train_loss', epoch_loss, global_step=epoch)
    print("In epoch: {:03d} | loss: {:.6f}".format(epoch, epoch_loss))

## Generate Images

Define meta data for generating new images.

In [None]:
# total number of foreground arts in the generated sequence, seed art included
fg_seq_len = 20
# key into bg_embs of the background whose embedding seeds the hidden state
init_bg = 8
# seed art. The generation cells read index 0 as the key into fg_embs,
# indices 2 and 3 as the x / y position, and indices 4-6 as scale, rotate
# (divided by 360 there) and opaque.
init_fg_meta = (15,0,-134,374,0.85,327,1)

In [None]:
# First input step: the seed art's embedding followed by its five placement
# values. The relative x / y offsets of the first art are 0 and the rotation
# is divided by 360.
seed_meta = torch.tensor([0, 0, init_fg_meta[4], init_fg_meta[5]/360, init_fg_meta[6]])
seed_step = torch.cat((fg_embs[init_fg_meta[0]], seed_meta))
# add the step axis and the batch axis, both of size 1
fg_emb_0 = seed_step.unsqueeze(dim=0).unsqueeze(dim=0).to(device)

# Initial state: the chosen background embedding repeated once per layer,
# with a batch axis of size 1; the cell state starts at zero.
bk_emb = bg_embs[init_bg]
h_0 = torch.stack([bk_emb] * args['num_layers']).unsqueeze(dim=1).to(device)
c_0 = torch.zeros_like(h_0).to(device)

print(f"h shape:{h_0.shape}; fg_emb.shape:{fg_emb_0.shape}; bk_emb.shape:{bk_emb.shape}")

Define metrics for comparing foreground embeddings.

In [None]:
def mse(A, B):
    '''
    Mean of the element-wise squared difference between A and B.
    '''
    diff = A - B
    return (diff * diff).mean()

def find_closest(emb, emb_dict):
    '''
    Return the key of the entry in emb_dict whose value has the smallest
    mse to emb. Ties go to the first such key in iteration order; an empty
    emb_dict makes min() raise ValueError.
    '''
    return min(emb_dict, key=lambda name: mse(emb, emb_dict[name]))

Generate a sequence of foreground arts

In [None]:
# Autoregressive generation: start from the seed art and repeatedly feed the
# model's own (snapped) prediction back in as the next input step.
h, c = h_0, c_0
fg_emb = fg_emb_0
fgs = [init_fg_meta]
# Generation is inference: leave the training mode set by the training loop
# so dropout / batch-norm layers (if the model has any) do not perturb the
# outputs.
seq_model.eval()
with torch.no_grad():
    for i in range(1, fg_seq_len):
        if i == 1:
            # first step: call the intact model with the background-derived
            # initial state and get the recurrent state back
            seqs_logits, cls_logits, (h, c) = seq_model(fg_emb, (h, c), return_last_hidden=True)
        else:
            # later steps: feed h and c straight to the rnn and apply the
            # output transform by hand
            output_seqs, (h, c) = seq_model.rnn_model(fg_emb, (h, c))
            seqs_logits = seq_model.seq_transformer(output_seqs)
        # process sequence logits to fit corresponding value scales
        # NOTE(review): sigmoid is applied to the (x, y, scale) slots and tanh
        # to the (angle, opaque) slots, while the training inputs normalize
        # x / y to [-1, 1] and rotate to [0, 1] -- confirm these slices match
        # what seq_loss_fn uses.
        seqs_logits[:, :, -5:-2] = torch.sigmoid(seqs_logits[:, :, -5:-2])
        seqs_logits[:, :, -2:] = torch.tanh(seqs_logits[:, :, -2:])
        fg_emb = seqs_logits[0,0,:576] # next fg embedding
        fg_meta = seqs_logits[0,0,576:] # (x, y, scale, angle, opaque)
        # snap the predicted embedding to the closest known art and continue
        # from that art's real embedding
        fg_name = find_closest(fg_emb.detach().cpu(), fg_embs)
        fg_emb = fg_embs[fg_name].to(device)
        fg_emb = torch.cat((fg_emb, fg_meta))
        fg_emb = fg_emb.unsqueeze(dim=0).unsqueeze(dim=0).to(device)
        # x / y are accumulated back to absolute positions;
        # again, magic numbers come from inspection of data
        # NOTE(review): scale, angle and opaque are appended as the activated
        # outputs; the angle is not multiplied back by 360, whereas
        # init_fg_meta is divided by 360 before use -- confirm the unit the
        # consumer of `fgs` expects.
        fgs.append([fg_name, i, fgs[-1][2]+fg_meta[0].item()*402, fgs[-1][3]+fg_meta[1].item()*600] +  [it.item() for it in fg_meta[2:]])

In [None]:
# display the generated sequence (seed entry first)
fgs

Define image generation related functions

In [None]:
def rotate_image(mat, angle):
    """
    Rotate an image by ``angle`` degrees about its center, enlarging the
    output canvas so the rotated image is not cropped.

    NOTE(review): relies on a module-level ``cv2`` import, which is not part
    of the cells shown here.
    """
    src_h, src_w = mat.shape[:2]
    # getRotationMatrix2D takes the center as (x, y), i.e. (width, height),
    # the reverse of the shape order
    center_x, center_y = src_w/2, src_h/2
    rotation_mat = cv2.getRotationMatrix2D((center_x, center_y), angle, 1.)

    # magnitudes of the cos / sin entries of the rotation matrix
    cos_mag = abs(rotation_mat[0,0])
    sin_mag = abs(rotation_mat[0,1])

    # size of the canvas needed for the rotated image
    bound_w = int(src_h * sin_mag + src_w * cos_mag)
    bound_h = int(src_h * cos_mag + src_w * sin_mag)

    # shift the translation so the old center lands on the new canvas center
    rotation_mat[0, 2] += bound_w/2 - center_x
    rotation_mat[1, 2] += bound_h/2 - center_y

    return cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))

def scale_image(mat, scale):
    """
    Resize ``mat`` by the factor ``scale`` along both axes, using bicubic
    interpolation.
    """
    return cv2.resize(mat, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

def transform_image(mat, scale, angle):
    '''
    Scale ``mat`` by ``scale`` first, then rotate the scaled image by ``angle``.

    NOTE(review): ``vanilla_rotate`` is not defined in the cells shown here
    (only ``rotate_image`` is) -- confirm it exists elsewhere in the notebook.
    '''
    return vanilla_rotate(scale_image(mat, scale), angle)

def overlay_transparent(background, overlay, x, y):
    '''
    Alpha-blend ``overlay`` onto ``background`` with its top-left corner at
    column ``x``, row ``y``. Parts of the overlay that fall outside the
    background, on any side, are cropped away.

    Parameters:
    background: array of shape (height, width, 4)
        modified in place. Four channels are required because all four
        overlay channels are blended into it.
    overlay: array of shape (h, w, channels)
        channel 3 is read as alpha in [0, 255]; with fewer than four
        channels a fully opaque alpha channel is appended
    x, y: int
        position of the overlay's top-left corner; may be negative

    Returns:
    ``background`` (the same object that was passed in)
    '''
    background_height, background_width = background.shape[:2]
    h, w = overlay.shape[:2]

    # nothing to draw when the overlay lies entirely outside the background
    if x >= background_width or y >= background_height or x + w <= 0 or y + h <= 0:
        return background

    # crop what sticks out on the left / top; without this a negative offset
    # would wrap around in the slice below
    if x < 0:
        overlay = overlay[:, -x:]
        w += x
        x = 0
    if y < 0:
        overlay = overlay[-y:]
        h += y
        y = 0

    # crop what sticks out on the right / bottom
    if x + w > background_width:
        w = background_width - x
        overlay = overlay[:, :w]
    if y + h > background_height:
        h = background_height - y
        overlay = overlay[:h]

    # no alpha channel: treat the overlay as fully opaque
    if overlay.shape[2] < 4:
        opaque_alpha = np.ones((overlay.shape[0], overlay.shape[1], 1), dtype=overlay.dtype) * 255
        overlay = np.concatenate([overlay, opaque_alpha], axis=2)

    overlay_image = overlay[..., :4]
    mask = overlay[..., 3:] / 255.0  # alpha as a blend weight in [0, 1]

    region = background[y:y+h, x:x+w]
    background[y:y+h, x:x+w] = (1.0 - mask) * region + mask * overlay_image

    return background