# Helper snippets used throughout this project (dataset preparation, visualization, and the siamese-network loss).

In [None]:
import csv
import os
import random

import cv2
import pandas as pd

In [None]:
# Summarize the class balance of a generated pair CSV.
def analyze_csv(filename):
    """Count positive and negative pairs in *filename*.

    Returns ``(ones, zeros)``: how many rows have ``label == 1`` and how many
    have ``label == 0``. Rows with any other label are counted in neither.
    """
    labels = pd.read_csv(filename)['label']
    positives = (labels == 1).sum()
    negatives = (labels == 0).sum()
    return positives, negatives


In [None]:
def generate_style(img_dir, csv_file, out_dir):
    """Copy every image listed in *csv_file* from *img_dir* into *out_dir*.

    Each image is decoded with OpenCV and re-encoded to ``out_dir/<filename>``.
    The output format follows the filename's extension, so this re-encodes
    rather than byte-copies. The row index is printed as progress.

    Args:
        img_dir: directory containing the source images.
        csv_file: CSV with a ``filename`` column naming images relative to
            *img_dir*.
        out_dir: destination directory. Relative paths are resolved against
            the current working directory. NOTE(review): cv2.imwrite does not
            create directories, so this presumably must already exist.
    """
    df = pd.read_csv(csv_file)
    for index, file in df['filename'].items():
        print(index)
        # Build the full path rather than os.chdir(img_dir). Changing the
        # process-wide working directory breaks later relative paths in the notebook.
        img = cv2.imread(os.path.join(img_dir, file))
        if img is None:
            # cv2.imread returns None for missing/unreadable files; imwrite would then raise.
            print('skipping unreadable image: {}'.format(file))
            continue
        if not cv2.imwrite(os.path.join(out_dir, file), img):
            print('failed to write image: {}'.format(file))

In [None]:
# (label-1 count, label-0 count) for the training pair CSV.
analyze_csv(filename='train_120k_pair.csv')

In [None]:
# (label-1 count, label-0 count) for the test pair CSV.
analyze_csv(filename='test_30k_pair.csv')

In [None]:
# Copy the romanticism paintings listed in the CSV from the regular test folder into their own folder.
path = '/home/tannhat_ng/'
# Named style_csv rather than `csv`: rebinding `csv` would shadow the csv module
# imported at the top, which generate_sub_csv relies on (csv.reader / csv.writer).
style_csv = os.path.join(path, 'test_art_romanticism.csv')
# NOTE: despite the train_ prefix, both of these point at test-split directories.
train_reg = os.path.join(path, 'test_reg')
train_roman = os.path.join(path, 'test_roman')
generate_style(train_reg, style_csv, train_roman)

In [None]:
SIZE = 10000


def generate_dataset(filepath, output_filepath, size=None):
    """Write a CSV of randomly drawn painting pairs for siamese training.

    Each output row is ``file1, file2, artist1, artist2, label``. ``label`` is
    1 when both paintings share an artist and 0 otherwise.

    For every pair, a first painting is sampled uniformly from the input. A
    fair coin then asks for either a same-artist pair or a different-artist
    pair. A same-artist request falls back to a different-artist pair when the
    artist has no other painting. The partner is drawn uniformly from the
    eligible paintings.

    Args:
        filepath: input CSV with at least ``filename`` and ``artist`` columns.
        output_filepath: destination CSV (overwritten).
        size: number of pairs to generate; defaults to the module-level ``SIZE``.

    Raises:
        ValueError: if pairs are requested but the input has fewer than two
            distinct artists, since no different-artist pair could be drawn.
    """
    if size is None:
        size = SIZE
    column = ['file1', 'file2', 'artist1', 'artist2', 'label']
    df = pd.read_csv(filepath)
    if size > 0 and df['artist'].nunique() < 2:
        raise ValueError('need at least two distinct artists to build different-artist pairs')

    # artist -> that artist's filenames, built once instead of re-filtering df for every pair.
    files_by_artist = df.groupby('artist')['filename'].apply(list).to_dict()

    rows = []
    for _ in range(size):
        want_same = random.randint(0, 1)
        row1 = df.sample()
        filename1 = row1['filename'].item()
        artist1 = row1['artist'].item()

        same_options = [f for f in files_by_artist[artist1] if f != filename1]
        if want_same and same_options:
            rows.append([filename1, random.choice(same_options), artist1, artist1, 1])
            continue

        # Rejection-sample a painting by another artist.
        # This terminates because >= 2 artists was checked above.
        row2 = df.sample()
        while row2['artist'].item() == artist1:
            row2 = df.sample()
        rows.append([filename1, row2['filename'].item(), artist1, row2['artist'].item(), 0])

    # Build the frame once. DataFrame.append was removed in pandas 2.0 and was quadratic here.
    output = pd.DataFrame(rows, columns=column)
    output.to_csv(output_filepath, columns=column, index=False)

In [None]:
# Generate a CSV listing only the images that exist in a particular folder.
def generate_sub_csv(filepath, output_filepath,
                     source_csv='/content/drive/My Drive/cs482/Painter Data/train_reg_info.csv'):
    """Filter *source_csv* down to the rows whose image file exists in *filepath*.

    The header row is copied unchanged. Each data row is kept when
    ``os.path.join(filepath, row[0])`` exists; the first column is taken to be
    the image filename. Empty rows are dropped.

    Args:
        filepath: directory holding the images (a trailing separator is optional).
        output_filepath: destination CSV path (overwritten).
        source_csv: CSV to filter. Defaults to the project's
            train_reg_info.csv on Google Drive, so two-argument calls keep working.
    """
    # Local import: a notebook cell rebinds the global name `csv` to a path string.
    import csv

    # newline='' is what the csv module documentation requires for both reading and writing.
    with open(source_csv, newline='') as src:
        reader = csv.reader(src, delimiter=',')
        header = next(reader, None)
        kept = [row for row in reader
                if row and os.path.exists(os.path.join(filepath, row[0]))]

    with open(output_filepath, 'w', newline='') as dst:
        writer = csv.writer(dst, delimiter=',')
        if header is not None:
            writer.writerow(header)
        writer.writerows(kept)

In [None]:
# Visualization: preview up to six test-set samples (starting at index 3) with their style labels.
painter_dataset = PainterDataset(csv_file=os.path.join(data_dir, 'test_info.csv'),
                                 root_dir=os.path.join(data_dir, 'test_reg'))

fig = plt.figure()

first_sample = 3
# Clamp the count so a dataset with fewer than nine items does not raise IndexError.
num_shown = max(0, min(6, len(painter_dataset) - first_sample))
for i in range(num_shown):
    sample = painter_dataset[first_sample + i]
    ax = plt.subplot(2, 3, i + 1)
    ax.set_title('{}'.format(sample["style"]))
    ax.axis('off')
    plt.imshow(sample["image"])

# Lay out and show once, after all subplots exist (also when fewer than six were drawn).
plt.tight_layout()
plt.show()

#plt.imshow(painter_dataset[0]["image"])
#print(painter_dataset[0]["style"])

In [None]:
class FaceDataset(Dataset):
    """Dataset of random image pairs for training a siamese network.

    ``index`` is ignored: every access draws a fresh pair. The first image is
    chosen uniformly from ``folder.imgs``. A fair coin then decides whether the
    second image comes from the same class (it may be the very same image) or
    from a different class. Images are converted to grayscale before
    ``transform`` is applied. ``len()`` is the number of images in the folder.

    Args:
        folder: object with an ``imgs`` sequence of ``(path, class_index)``
            pairs. Presumably a torchvision ``ImageFolder``; confirm.
        transform: callable applied to each PIL image, or ``None``.
    """

    def __init__(self, folder, transform):
        self.folder = folder
        self.transform = transform
        # Number of distinct classes; used to detect when a different-class pair is impossible.
        self._num_classes = len({wrapper[1] for wrapper in folder.imgs})

    @staticmethod
    def _load_gray(path):
        """Load *path* as a grayscale ('L' mode) PIL image, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` for a freshly drawn random pair.

        ``label`` is a float64 tensor of shape (1,). It is 1.0 when both images
        share a class and 0.0 otherwise.

        Raises:
            ValueError: a different-class pair was drawn but the folder has a
                single class, so no such pair exists.
        """
        imgs = self.folder.imgs
        # Each entry is a (path, class_index) tuple.
        img1_wrapper = random.choice(imgs)

        # Fair coin: 1 -> same-class pair, 0 -> different-class pair.
        same = random.randint(0, 1)
        if not same and self._num_classes < 2:
            raise ValueError('FaceDataset needs at least two classes to draw a different-class pair')
        # Rejection-sample the partner until its class relation matches the coin.
        while True:
            img2_wrapper = random.choice(imgs)
            if (img1_wrapper[1] == img2_wrapper[1]) == bool(same):
                break

        img1 = self._load_gray(img1_wrapper[0])
        img2 = self._load_gray(img2_wrapper[0])

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # 1.0 = same class, 0.0 = different class.
        label = np.array([img1_wrapper[1] == img2_wrapper[1]], dtype=np.float64)
        label = torch.from_numpy(label)

        return img1, img2, label

    def __len__(self):
        return len(self.folder.imgs)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss of Hadsell, Chopra & LeCun (2006).

    Resources: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Element-wise loss::

        0.5 * ((1 - target) * distance**2 + target * max(0, margin - distance)**2)

    so ``target == 0`` pulls a pair together, and ``target == 1`` pushes it
    apart until it is at least ``margin`` away.

    NOTE(review): FaceDataset in this notebook emits label 1 for *same-class*
    pairs, which is the opposite of this convention. Confirm that the training
    loop feeds ``1 - label`` (or an equivalent) here.

    Args:
        margin: distance beyond which target-1 pairs contribute no loss.
        size_average: when reducing, average (True) or sum (False).
        reduce: if False, return the unreduced element-wise loss.
    """

    def __init__(self, margin=1.0, size_average=True, reduce=True):
        super().__init__()
        self.margin = margin
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, distance, target):
        pull = (1 - target) * torch.pow(distance, 2)
        push = target * torch.pow(F.relu(self.margin - distance), 2)
        per_pair = 0.5 * (pull + push)
        if not self.reduce:
            return per_pair
        return torch.mean(per_pair) if self.size_average else torch.sum(per_pair)