In [None]:
# Mount Google Drive at /content/drive; the dataset CSV and emoji images are read from there.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Install KoBERT from the GitHub master branch; presumably this provides the `kobert`
# package imported below (and pulls in gluonnlp) — confirm against the repo's requirements.
# NOTE(review): unpinned — master can change underneath this notebook; consider pinning a commit.
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [None]:
# Replace the preinstalled torch/torchvision with pinned CUDA 11.1 builds
# (torch 1.10.1, torchvision 0.11.2) from the PyTorch wheel index.
# NOTE(review): a runtime restart is usually needed before the reinstalled torch is imported;
# also confirm these wheels are still published and match the runtime's CUDA driver.
# !pip install update torch --extra-index-url https://download.pytorch.org/whl/cu113
!pip uninstall torch -y
!pip uninstall torchvision -y
!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader,random_split
import gluonnlp as nlp
import numpy as np
from tqdm.notebook import tqdm
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
import torchvision.transforms as T
import pandas as pd
import torchvision.utils as vutils

# Prefer the GPU but fall back to CPU so the notebook still runs (slowly) without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so weight initialization and DataLoader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

In [None]:
class BERTBcdata(Dataset):
    """Training dataset pairing each comment's KoBERT encoding with its label's emoji image.

    Rows come from ``csv_path`` (rows containing any NaN are dropped). Each row's ``comment``
    is encoded with ``tok``; its integer ``Emotion`` label selects the image
    ``{img_path}/{Emotion}.jpg``.

    Args:
        csv_path: CSV file with at least ``comment`` and ``Emotion`` columns.
        img_path: directory holding one ``<label>.jpg`` per emotion label.
        tok: tokenizer passed to ``nlp.data.BERTSentenceTransform``.
        max_seq_length: token length passed to ``BERTSentenceTransform`` (padded).
        img_shape: side length of the square center crop.
        mean, std: per-channel normalization. The default (0.5, 0.5, 0.5) maps pixels to
            [-1, 1], the output range of the Tanh-terminated Generator. ImageNet statistics
            would put real pixels as high as (1 - 0.406) / 0.225 ~= 2.6, values a fake can never
            reach, which lets the discriminator tell real from fake trivially.

    Items are ``(token_ids, valid_length, segment_ids, image)``.
    """

    def __init__(self, csv_path, img_path, tok, max_seq_length, img_shape=64,
                 mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        self.transform = nlp.data.BERTSentenceTransform(tok, max_seq_length=max_seq_length, pad=True, pair=False)
        self.df = pd.read_csv(csv_path).dropna()
        self.img_path = img_path
        self.img_shape = img_shape
        # Image preprocessing pipeline (attribute name kept for compatibility).
        self.transpose = T.Compose([
            T.Resize(img_shape + 8),
            T.CenterCrop(img_shape),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])
        # label -> preprocessed image tensor. The pipeline above is deterministic, so each
        # label's image only has to be decoded once instead of once per row.
        self._img_cache = {}

    def _load_image(self, label):
        """Return the preprocessed image tensor for ``label``, decoding it on first use."""
        if label not in self._img_cache:
            img_name = '{}/{}.jpg'.format(self.img_path, label)
            # `with` closes the file handle; convert('RGB') guarantees 3 channels so the
            # 3-channel Normalize also works for grayscale / palette / CMYK files.
            with Image.open(img_name) as img:
                self._img_cache[label] = self.transpose(img.convert('RGB'))
        return self._img_cache[label]

    def __getitem__(self, i):
        row = self.df.iloc[i]
        sentences = self.transform([str(row['comment'])])
        label = int(row['Emotion'])
        return sentences[0], sentences[1], sentences[2], self._load_image(label)

    def __len__(self):
        return len(self.df)


class BERTBcTestdata(Dataset):
    """Text-only dataset: yields the KoBERT encoding of each comment in ``csv_path``.

    Rows with any NaN are dropped. ``img_path`` and ``img_shape`` are accepted for
    signature parity with ``BERTBcdata`` but are not used.

    Items are ``(token_ids, valid_length, segment_ids)``.
    """

    def __init__(self, csv_path, img_path, tok, max_seq_length, img_shape=64):
        self.transform = nlp.data.BERTSentenceTransform(
            tok, max_seq_length=max_seq_length, pad=True, pair=False)
        raw_frame = pd.read_csv(csv_path)
        self.df = raw_frame.dropna()

    def __getitem__(self, i):
        text = str(self.df.iloc[i]['comment'])
        encoded = self.transform([text])
        token_ids, valid_length, segment_ids = encoded[0], encoded[1], encoded[2]
        return token_ids, valid_length, segment_ids

    def __len__(self):
        return self.df.shape[0]


In [None]:
# Load the pretrained KoBERT model and its vocabulary (cached under .cache), then wrap the
# KoBERT SentencePiece tokenizer in a gluonnlp BERT tokenizer that the datasets use.
# NOTE(review): get_tokenizer() presumably returns the SentencePiece model file path — confirm.
bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")
tokenizer = get_tokenizer()
# lower=False presumably preserves the input's casing — confirm against gluonnlp docs.
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

In [None]:
# Directory holding the training CSV and the per-label emoji images (<label>.jpg).
EMOJI_DIR = '/content/drive/Shareddrives/선형대수학/주제분석/emoji'
MAX_SEQ_LEN = 100  # max_seq_length passed to BERTSentenceTransform (sequences are padded)

dataset = BERTBcdata('{}/취합1.csv'.format(EMOJI_DIR), EMOJI_DIR, tok, MAX_SEQ_LEN)
len(dataset)

In [None]:
# Hold out a fixed-size validation set; the seeded generator makes the split reproducible.
val_size = 16
data_size = len(dataset)
train_size = data_size - val_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator)
len(train_dataset)

In [None]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
test_loader = DataLoader(val_dataset, batch_size=16)  # iterates the validation split

# Sanity-check one training batch: token_ids, valid_length, segment_ids, image.
for data in train_loader:
    print(*(part.shape for part in data), sep='\n')
    break

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing a real/fake probability per image.

    For an ``nc x 64 x 64`` input the output has shape ``(batch, 1, 1, 1)`` with values
    in (0, 1) from the final sigmoid.

    Args:
        nc: input image channels.
        ndf: base feature-map width; doubled after each strided conv.
    """

    def __init__(self, nc=3, ndf=8):
        super(Discriminator, self).__init__()

        # nc x 64 x 64 -> ndf x 32 x 32 (no BatchNorm on the first layer)
        self.conv1 = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)
        self.lkr1 = nn.LeakyReLU(0.2, inplace=True)
        # -> (ndf*2) x 16 x 16
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)
        self.bm2 = nn.BatchNorm2d(ndf * 2)
        self.lkr2 = nn.LeakyReLU(0.2, inplace=True)
        # -> (ndf*4) x 8 x 8
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        self.bm3 = nn.BatchNorm2d(ndf * 4)
        self.lkr3 = nn.LeakyReLU(0.2, inplace=True)
        # -> (ndf*8) x 4 x 4
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)
        self.bm4 = nn.BatchNorm2d(ndf * 8)
        self.lkr4 = nn.LeakyReLU(0.2, inplace=True)
        # -> 1 x 1 x 1, squashed to a probability
        self.conv5 = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
        self.sig = nn.Sigmoid()

        # Registered but not applied in forward().
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        pipeline = (
            self.conv1, self.lkr1,
            self.conv2, self.bm2, self.lkr2,
            self.conv3, self.bm3, self.lkr3,
            self.conv4, self.bm4, self.lkr4,
            self.conv5, self.sig,
        )
        for stage in pipeline:
            x = stage(x)
        return x

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps a ``(batch, nz, 1, 1)`` latent to a ``(batch, nc, 64, 64)`` image.

    The final Tanh puts every output value in [-1, 1].

    Args:
        nz: latent channel count.
        ngf: base feature-map width (the first stage uses ``ngf * 8``).
        nc: output image channels.
    """

    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        # 1 x 1 latent -> (ngf*8) x 4 x 4
        layers = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Three stages, each doubling H/W and halving channels:
        # 4x4 -> (ngf*4) x 8x8 -> (ngf*2) x 16x16 -> ngf x 32x32
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(ngf * mult_in, ngf * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * mult_out),
                nn.ReLU(True),
            ]
        # ngf x 32x32 -> nc x 64x64, squashed to [-1, 1]
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        image = self.main(input)
        return image

In [None]:
class BGenerator(nn.Module):
    """Text-conditioned image generator: KoBERT pooled embedding -> DCGAN ``Generator``.

    Args:
        bertmodel: pretrained KoBERT model; moved to ``device`` and partially frozen.
        bert_dim: size of the pooled embedding, used as the generator's latent channels.
        ngf, nc: ``Generator`` base width and output channels.
        device: device the submodules are moved to.
        freezing_layer_num: parameters of encoder layers with index <= this value are frozen;
            embedding parameters are always frozen. NOTE(review): with the default 11 this
            presumably freezes every encoder layer of a 12-layer KoBERT — confirm.
    """

    def __init__(self, bertmodel, bert_dim=768, ngf=8, nc=3, device='cuda', freezing_layer_num=11):
        super(BGenerator, self).__init__()
        self.bert_dim = bert_dim
        self.bertmodel = bertmodel.to(device)
        for name, param in self.bertmodel.named_parameters():
            if 'embeddings' in name:
                param.requires_grad = False

            if 'encoder.layer' in name:
                # Names look like 'encoder.layer.<idx>.<...>'; field 2 is the layer index.
                layer_idx = int(name.split('.')[2])
                if layer_idx <= freezing_layer_num:
                    param.requires_grad = False
        self.generator = Generator(bert_dim, ngf, nc).to(device)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids`` (batch, seq_len).

        Row ``i`` is 1.0 on its first ``valid_length[i]`` positions and 0.0 elsewhere.
        Built with one broadcast comparison instead of a per-row Python loop.
        """
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        lengths = torch.as_tensor(valid_length, device=token_ids.device).view(-1, 1)
        return (positions.unsqueeze(0) < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        # NOTE(review): assumes the model returns a (sequence_output, pooled_output) pair;
        # newer transformers versions may need return_dict=False for this unpacking — confirm.
        _, pooled = self.bertmodel(input_ids=token_ids, token_type_ids=segment_ids.long(),
                                   attention_mask=attention_mask)
        # (batch, bert_dim) -> (batch, bert_dim, 1, 1): the generator's 1x1 latent input.
        return self.generator(pooled.view(-1, self.bert_dim, 1, 1))

    

In [None]:
# Text-conditioned generator (KoBERT + DCGAN decoder) and image discriminator, both on `device`.
bg = BGenerator(bertmodel, device=device)
d = Discriminator().to(device)
# To resume training, load previously saved state dicts here, e.g.:
# bg.load_state_dict(torch.load('path'))
# d.load_state_dict(torch.load('path'))

# Binary real/fake objective; Adam with the same learning rate for both players.
criterion = nn.BCELoss()
optimizerD, optimizerG = (optim.Adam(model.parameters(), lr=2e-4) for model in (d, bg))

In [None]:
NUM_EPOCHS = 3000
SNAPSHOT_EVERY = 10  # epochs between validation-sample snapshots and checkpoint saves

img_list = []   # make_grid images of the validation samples, one per snapshot (animation)
G_losses = []   # generator loss, one entry per training iteration
D_losses = []   # discriminator loss (fake + real terms), one entry per training iteration
img_list1 = []  # raw validation sample batches on CPU, one per snapshot (grid plot)

for e in range(NUM_EPOCHS):

    trainD_loss = 0
    trainG_loss = 0
    bg.train()
    d.train()
    for data in train_loader:
        token_ids, valid_length, segment_ids, real_img = (t.to(device) for t in data)

        # --- Discriminator update: push D(fake) -> 0 and D(real) -> 1 ---
        d.zero_grad()
        fake_img = bg(token_ids, valid_length, segment_ids)

        # detach(): the D loss must not backpropagate through the generator/BERT. Because of
        # that, no retain_graph is needed; the G step below builds its own graph through D.
        fake_result = d(fake_img.detach()).view(-1)
        fakeD_loss = criterion(fake_result, torch.zeros_like(fake_result))

        true_result = d(real_img).view(-1)
        trueD_loss = criterion(true_result, torch.ones_like(true_result))

        lossD = fakeD_loss + trueD_loss
        lossD.backward()
        optimizerD.step()

        # --- Generator update: push the freshly updated D(fake) -> 1 ---
        bg.zero_grad()
        fake_result = d(fake_img).view(-1)
        fakeG_loss = criterion(fake_result, torch.ones_like(fake_result))
        fakeG_loss.backward()
        optimizerG.step()

        lossD_value = lossD.item()
        lossG_value = fakeG_loss.item()
        trainD_loss += lossD_value
        trainG_loss += lossG_value
        # Record this iteration's loss in the list named after its own player.
        D_losses.append(lossD_value)
        G_losses.append(lossG_value)

    print('{}epoch trainD_loss:{}'.format(e+1, trainD_loss/len(train_loader)))
    print('{}epoch trainG_loss:{}'.format(e+1, trainG_loss/len(train_loader)))
    bg.eval()

    if (e + 1) % SNAPSHOT_EVERY == 0:
        # Sample validation images without building an autograd graph.
        with torch.no_grad():
            for data in test_loader:
                token_ids, valid_length, segment_ids = (t.to(device) for t in data[:3])
                fake_img = bg(token_ids, valid_length, segment_ids).cpu()
                img_list1.append(fake_img)
                img_list.append(vutils.make_grid(fake_img, padding=2, normalize=True))
        torch.save(bg.state_dict(), './bg_{}.pt'.format(e+1))
        torch.save(d.state_dict(), './d_{}.pt'.format(e+1))
    

In [None]:
import matplotlib.pyplot as plt

# Cap on plotted snapshot rows: one snapshot every 10 epochs over 3000 epochs gives up to
# 300 snapshots, and a figure 5 inches tall per row would be impractically large.
MAX_ROWS = 10

n_snapshots = len(img_list1)
row = min(n_snapshots, MAX_ROWS)
col = len(img_list1[0]) if n_snapshots else 0
print(row)
print(col)
if row and col:
    # Evenly spaced snapshot indices, always including the first and the last one.
    picked = np.linspace(0, n_snapshots - 1, row).round().astype(int)
    # squeeze=False keeps `axes` 2-D even when there is a single row or column.
    fig, axes = plt.subplots(row, col, figsize=(5 * col, 5 * row), squeeze=False)
    for i, snap in enumerate(picked):
        for j in range(col):
            # Generator output comes from Tanh ([-1, 1]). imshow clips float RGB to [0, 1],
            # so rescale first instead of losing the negative half of the range.
            img = ((img_list1[snap][j] + 1) / 2).clamp(0, 1).permute(1, 2, 0).numpy()
            axes[i][j].imshow(img)
            axes[i][j].axis('off')


In [None]:
# Training losses of both players, one point per recorded iteration.
loss_ax = plt.figure(figsize=(10, 5)).gca()
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Flip through the snapshot grids in img_list: one frame per snapshot, 1000 ms apart.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = list(map(lambda grid: [plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)], img_list))
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [None]:
# NOTE(review): matplotlib's 'imagemagick' writer needs the ImageMagick system binary, which is
# normally installed with the OS package manager (on Colab: `!apt-get install -y imagemagick`),
# not via pip — confirm this command actually provides it.
!pip install imagemagick

In [None]:
# Save the training-progress animation built above (`ani`) as a GIF. The 'pillow' writer
# only needs Pillow (already imported as PIL), not an external ImageMagick install.
# NOTE(review): the filename looks copied from a tutorial (frames are 1000 ms apart) — consider renaming.
ani.save('sine_wave_interval_100ms.gif', writer='pillow')