This notebook formats the train and test data into the "prompt"/"completion" format required for GPT-3 fine-tuning (https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset). Each prompt has the format "artist;topic_id".

### Set-up

In [None]:
# Mount Google Drive, then work from the shared project's data/ folder so the
# relative '../data/...' paths used in later cells resolve inside it.
import os
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
PROJECT_DATA_DIR = DRIVE_MOUNT_POINT + 'Shareddrives/CS260-Project/data/'

drive.mount(DRIVE_MOUNT_POINT)
os.chdir(PROJECT_DATA_DIR)

Mounted at /content/drive/


In [None]:
import csv
import random

### Reformat training and test sets

In [None]:
# Format the training set: turn each source row into a GPT-3 prompt/completion
# pair, following the separator conventions of the fine-tuning guide.
#   prompt:     "<row[0]>;<row[1]>\n\n###\n\n"  i.e. "artist;topic_id" + a fixed
#               separator marking the end of the prompt
#   completion: " <row[2]>###"  leading space, and "###" as the stop sequence
# newline='' is required by the csv module so that newlines embedded in quoted
# fields (the prompt separator, multi-line completions) round-trip correctly.
with open('../data/train/big-lda-train-40.csv', 'r', newline='') as orig_data:
  reader = csv.reader(orig_data, delimiter=',')
  # Skip exactly one header row. (Testing `line_in > 1` also discarded the
  # first data row.)
  next(reader, None)
  with open('../data/train/big-lda-train-40-formatted.csv', 'w', newline='') as formatted_data:
    writer = csv.writer(formatted_data, delimiter=',')
    writer.writerow(['prompt', 'completion'])

    for row in reader:
      if not row:
        continue  # csv.reader yields [] for blank lines
      # uses recommended separators and formatting
      prompt = row[0] + ";" + row[1] + "\n\n###\n\n"
      completion = " " + row[2] + "###"
      writer.writerow([prompt, completion])
    

In [None]:
# Format the test set exactly like the training set: each source row becomes
#   prompt:     "<row[0]>;<row[1]>\n\n###\n\n"  ("artist;topic_id" + separator)
#   completion: " <row[2]>###"  (leading space, "###" stop sequence)
# newline='' is required by the csv module so that newlines embedded in quoted
# fields round-trip correctly.
with open('../data/test/big-lda-test-40.csv', 'r', newline='') as orig_data:
  reader = csv.reader(orig_data, delimiter=',')
  # Skip exactly one header row. (Testing `line_in > 1` also discarded the
  # first data row.)
  next(reader, None)
  with open('../data/test/big-lda-test-40-formatted.csv', 'w', newline='') as formatted_data:
    writer = csv.writer(formatted_data, delimiter=',')
    writer.writerow(['prompt', 'completion'])

    for row in reader:
      if not row:
        continue  # csv.reader yields [] for blank lines
      # uses recommended separators and formatting
      prompt = row[0] + ";" + row[1] + "\n\n###\n\n"
      completion = " " + row[2] + "###"
      writer.writerow([prompt, completion])

### For GPT-3 fine-tunes, sample smaller training and validation sets

In [None]:
# Take the first 1000 data rows of the formatted training set as a smaller
# GPT-3 training file, and the following 100 rows as a validation file.
# Rows are taken in file order (no shuffling).
n_train, n_val = 1000, 100

with open('../data/train/big-lda-train-40-formatted.csv', 'r', newline='') as orig_data:
  reader = csv.reader(orig_data, delimiter=',')
  next(reader, None)  # skip the source's header row; both outputs get their own
  with open('../data/train/40-topic-sample-1000-train.csv', 'w', newline='') as formatted_data:
    with open('../data/val/40-topic-sample-1000-val.csv', 'w', newline='') as val_set:
      writer = csv.writer(formatted_data, delimiter=',')
      writer_val = csv.writer(val_set, delimiter=',')
      # Write both headers explicitly rather than relying on the source's
      # header row falling through to the validation branch.
      writer.writerow(['prompt', 'completion'])
      writer_val.writerow(['prompt', 'completion'])

      # `line` is the 1-based data-row number (header excluded).
      for line, row in enumerate(reader, start=1):
        if line <= n_train:
          writer.writerow(row)
        elif line <= n_train + n_val:
          writer_val.writerow(row)
        else:
          break  # rows past the validation window are unused


In [None]:
# Remove the validation rows from the formatted training data (writing to a
# new file), so validation examples do not also appear in training.
# NOTE(review): this assumes the validation set is data rows 1001-1100
# (1-based, header excluded), i.e. the same window the 40-topic sampling cell
# uses. Confirm this against how the lda-train-6 validation file was produced.
# The earlier test (`line > 1 and not (line > 1002 and line < 1102)`) dropped
# the first data row and removed only 99 rows (1003-1101).
val_start, val_end = 1001, 1100  # inclusive, 1-based data-row numbers

with open('../data/train/lda-train-6-formatted.csv', 'r', newline='') as formatted_data:
  reader = csv.reader(formatted_data, delimiter=',')
  next(reader, None)  # skip the header; a fresh one is written below
  with open('../data/train/lda-train-6-clean.csv', 'w', newline='') as clean:
    writer = csv.writer(clean, delimiter=',')
    writer.writerow(['prompt', 'completion'])

    for line, row in enumerate(reader, start=1):
      if not val_start <= line <= val_end:
        writer.writerow(row)

In [None]:
# For each topic, take the first 100 rows of that topic as a training file and
# the next 10 rows as a validation file (file order, no shuffling).
# NOTE(review): topics with no rows in the source produce header-only files.
n_topics, n_train, n_val = 10, 100, 10


def _topic_of(prompt):
  """Return the integer topic id from a prompt of the form 'artist;topic_id' + separator.

  Splits on the last ';' so that a ';' inside the artist name does no harm,
  and reads the whole id up to the first newline of the prompt separator.
  This way a multi-digit id such as 12 is not mistaken for topic 1.
  """
  return int(prompt.rsplit(';', 1)[1].split('\n', 1)[0])


for topic in range(n_topics):
  train_path = f'../data/train/{topic}-sample-{n_train}-train.csv'
  val_path = f'../data/val/{topic}-sample-{n_train}-val.csv'
  with open('../data/train/lda-train-6-formatted.csv', 'r', newline='') as orig_data:
    reader = csv.reader(orig_data, delimiter=',')
    next(reader, None)  # skip the header row
    with open(train_path, 'w', newline='') as formatted_data:
      with open(val_path, 'w', newline='') as val_set:
        writer = csv.writer(formatted_data, delimiter=',')
        writer.writerow(['prompt', 'completion'])

        writer_val = csv.writer(val_set, delimiter=',')
        writer_val.writerow(['prompt', 'completion'])

        count = 0  # rows of this topic written so far
        for row in reader:
          if _topic_of(row[0]) != topic:
            continue
          if count < n_train:
            writer.writerow(row)
          else:
            writer_val.writerow(row)
          count += 1
          if count >= n_train + n_val:
            break  # this topic's quota is filled; stop scanning the file