In [None]:
import os
import json
import random
import re

In [None]:
## S3 Access

In [None]:
import boto3
from sagemaker import get_execution_role

In [None]:
# Execution role reported by the SageMaker SDK for this notebook.
# NOTE(review): `role` is not referenced again in the visible cells — confirm it is still needed.
role = get_execution_role()
# Bucket and object key of the zip archive to fetch.
bucket='devopstar'
data_key = 'resources/fbmsg-analysis-gpt-2/facebook.zip'

# Download the object into a local file named facebook.zip. The file is opened
# with 'wb', so an existing local copy is truncated and replaced.
s3 = boto3.resource('s3')
with open('facebook.zip', 'wb') as data:
    s3.Bucket(bucket).download_fileobj(data_key, data)

In [None]:
!unzip facebook.zip

In [None]:
## Download Dependencies

In [None]:
!pip install --upgrade pip
!pip install -r requirements.txt

In [None]:
## Download Model

In [None]:
# Run download_model.sh (expected in the working directory) with the model
# name 117M as its only argument.
# NOTE(review): presumably this populates models/117M/, which a later cell
# writes checkpoints into — confirm against the script.
!sh download_model.sh 117M

In [None]:
## Get List of files

In [None]:
# Collect the path of every file under messages/inbox whose name ends with
# 'message.json'. Paths are built as '<directory>/<filename>'.
files = [
    f'{dirpath}/{filename}'
    for dirpath, _dirnames, filenames in os.walk('messages/inbox')
    for filename in filenames
    if filename.endswith('message.json')
]

# Show how many conversation files were found.
len(files)

In [None]:
## Helper Functions

In [None]:
def fix_encoding(s):
    """Repair UTF-8 text that was mis-decoded as Latin-1 (mojibake).

    Every run consisting of one character in U+00C2..U+00F4 followed by one
    or more characters in U+0080..U+00BF is re-encoded as Latin-1 and decoded
    as UTF-8, recovering the intended character.

    Args:
        s: The string to repair.

    Returns:
        A new string with each decodable run replaced by its UTF-8 decoding.
        A run that is not valid UTF-8 (for example a two-byte lead followed
        by two continuation characters) is left exactly as it was instead of
        raising ``UnicodeDecodeError``.
    """
    def _redecode(match):
        garbled = match.group(0)
        try:
            return garbled.encode('latin1').decode('utf8')
        except UnicodeDecodeError:
            # Looks like mojibake but is not a valid UTF-8 sequence; keep the
            # text as-is rather than aborting the caller's whole run.
            return garbled

    return re.sub('[\xc2-\xf4][\x80-\xbf]+', _redecode, s)

def find_cyrilic(s):
    """Tell whether ``s`` contains at least one Cyrillic letter.

    The check is a case-insensitive regex search for a character in the
    range А-Я or the letter Ё.

    Args:
        s: The string to inspect.

    Returns:
        ``True`` if a matching character is present, otherwise ``False``.
    """
    return re.search('(?i)[А-ЯЁ]', s) is not None

def test_mostly_cyrilic(messages):
    i = 0
    check_n = min(250, len(messages))
    for msg in random.sample(messages, check_n):
        try:
            i +=find_cyrilic(fix_encoding(msg['content'])) or find_cyrilic(fix_encoding(msg['sender_name']))
        except KeyError:
            check_n -=1
    return i > check_n/5

In [None]:
## Load Messages

In [None]:
### All Names

In [None]:
def create_file(files=files, banned_names=()):
    """Build one text corpus from conversation files and write it to ``fb-cleaned.txt``.

    For each file the JSON ``messages`` list is loaded and reversed
    (NOTE(review): presumably the export lists newest first, so reversing
    gives chronological order — confirm). Conversations judged mostly
    Cyrillic by ``test_mostly_cyrilic`` are skipped entirely. Every remaining
    message is rendered as ``(timestamp_ms) sender_name: content`` on its own
    line, with both the sender name and the content passed through
    ``fix_encoding``; lines that still contain Cyrillic are dropped. The
    corpus ends with two newlines.

    Args:
        files: Iterable of paths to conversation JSON files. The default is
            the module-level ``files`` list as bound when this function was
            defined.
        banned_names: Iterable of substrings; a file whose path contains any
            of them is skipped. Empty by default, so nothing is excluded.

    Side effects:
        Prints the path of every conversation that was included, prints a
        ``Skipping ...`` line for every file that could not be parsed, and
        overwrites ``fb-cleaned.txt`` in the working directory.
    """
    corpus_parts = []  # joined once at the end; avoids quadratic string growth
    for file in files:
        if any(bn in file for bn in banned_names):
            continue

        with open(file, 'r', encoding='utf-8') as f:
            try:
                msgs = json.load(f)['messages']
                msgs.reverse()
            except Exception as exc:
                # Best effort: a malformed file must not abort the whole run,
                # but report it instead of silently dropping it. Catching
                # Exception (not a bare except) lets Ctrl-C still interrupt.
                print(f'Skipping {file}: {exc!r}')
                continue

        if test_mostly_cyrilic(msgs):
            continue

        for msg in msgs:
            try:
                # Decode the sender as well as the content so names are not
                # left garbled and the Cyrillic filter below can see them.
                sender = fix_encoding(msg['sender_name'])
                content = fix_encoding(msg['content'])
                to_add = f"({msg['timestamp_ms']}) {sender}: {content}\n"
            except KeyError:
                # Message lacks one of the required fields; leave it out.
                continue
            if not find_cyrilic(to_add):
                corpus_parts.append(to_add)
        print(file)

    corpus_parts.append('\n\n')
    # Explicit UTF-8 so non-ASCII text can be written whatever the locale is.
    with open('fb-cleaned.txt', 'w', encoding='utf-8') as f:
        f.write(''.join(corpus_parts))

In [None]:
### Specify Particular Person

In [None]:
def create_specific_file(person, files=files):
    """Build a text corpus from one person's conversations and write it to ``fb-cleaned-{person}.txt``.

    Only files whose path contains ``person`` are used. For each one the JSON
    ``messages`` list is loaded and reversed (NOTE(review): presumably the
    export lists newest first, so reversing gives chronological order —
    confirm). Every message is rendered as
    ``(timestamp_ms) sender_name: content`` on its own line, with both the
    sender name and the content passed through ``fix_encoding``; lines that
    still contain Cyrillic are dropped. The corpus ends with two newlines.

    Args:
        person: Substring to look for in each file path; also used verbatim
            in the output file name.
        files: Iterable of paths to conversation JSON files. The default is
            the module-level ``files`` list as bound when this function was
            defined.

    Side effects:
        Prints the path of every matching file, prints a ``Skipping ...``
        line for every matching file that could not be parsed, and overwrites
        ``fb-cleaned-{person}.txt`` in the working directory.
    """
    corpus_parts = []  # joined once at the end; avoids quadratic string growth
    for file in files:
        if person not in file:
            continue
        print(file)

        with open(file, 'r', encoding='utf-8') as f:
            try:
                msgs = json.load(f)['messages']
                msgs.reverse()
            except Exception as exc:
                # Best effort: a malformed file must not abort the whole run,
                # but report it instead of silently dropping it. Catching
                # Exception (not a bare except) lets Ctrl-C still interrupt.
                print(f'Skipping {file}: {exc!r}')
                continue

        for msg in msgs:
            try:
                # Decode the sender as well as the content so names are not
                # left garbled and the Cyrillic filter below can see them.
                sender = fix_encoding(msg['sender_name'])
                content = fix_encoding(msg['content'])
                to_add = f"({msg['timestamp_ms']}) {sender}: {content}\n"
            except KeyError:
                # Message lacks one of the required fields; leave it out.
                continue
            if not find_cyrilic(to_add):
                corpus_parts.append(to_add)

    corpus_parts.append('\n\n')
    # Explicit UTF-8 so non-ASCII text can be written whatever the locale is.
    with open(f'fb-cleaned-{person}.txt', 'w', encoding='utf-8') as f:
        f.write(''.join(corpus_parts))

In [None]:
### Run

In [None]:
# Build the corpus from every discovered conversation file; the result is
# written to fb-cleaned.txt, which the training cell reads.
create_file(files)

In [None]:
## Train

In [None]:
# Step 1: run encode.py on the cleaned corpus, naming fb-cleaned.txt.npz as the
# output. PYTHONPATH=src makes the modules under src/ importable by the script.
!PYTHONPATH=src ./encode.py --in-text fb-cleaned.txt --out-npz fb-cleaned.txt.npz
# Step 2: run train.py on the encoded dataset with the options given below.
# NOTE(review): presumably --stop_after=251 ends the run just after the first
# sample at step 250 — confirm against train.py.
!PYTHONPATH=src ./train.py --dataset fb-cleaned.txt.npz --sample_every=250 --learning_rate=0.0001 --stop_after=251

In [None]:
## Generate Samples

In [None]:
mv checkpoint/run1/* models/117M/

In [None]:
# Run the unconditional sample generator with top_k=40 and temperature=0.9.
# NOTE(review): presumably it loads the model from models/117M/ — confirm
# against the script's defaults.
!python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.9