## Recap
### The String Manipulation Task

The task is to synthesize string-manipulating programs that combine Python's `split`, `join`, and indexing operations to transform an input string into a desired output string. As an example, suppose a user wants a program that extracts the first name from a person's full name, where the full name consists of the first name and the last name separated by a single space (`' '`). The following Python program solves the task, where `inp` is the variable holding the input string.

```python
inp.split(' ')[0]
```
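
To make the three primitives concrete, here is a small plain-Python sketch. The variable names `parts`, `first_name`, and `dotted` are illustrative only.

```python
inp = "Alan Turing"

# split: break the string into a list at every occurrence of the separator
parts = inp.split(' ')

# indexing: pick one element of the resulting list
first_name = parts[0]

# join: glue a list back together with a (possibly different) separator
dotted = '.'.join(parts)

print(parts, first_name, dotted)
```

Every program in the mini-language is a nesting of these three operations applied to `inp` and to constant strings.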

## Writing a Synthesizer for a Language Containing `split`, `join` and `indexing`

As mentioned in the previous tutorial, we will now write a generator for the full mini-language we introduced. In particular, the generator should be able to produce every program built from these primitives. First, let us define the language formally.

Let `S` be a program that takes an input string or a list of strings (as in `join`) and produces an output string. Let `L` be a program that takes an input string and produces an output list of strings (as in `split`). Our programs of interest can then be expressed succinctly with the following grammar:

```
S ::= inp | substring of input | substring of output | L[i] | S.join(L)
L ::= S.split(S)
```

Here `L[i]` is the indexing operation. Note that the grammar includes substrings of the input and of the output. These express the constant separator strings commonly passed to `split` and `join`.
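
As a rough illustration of what this grammar describes, the following is a plain-Python brute-force enumerator. It is a sketch under stated assumptions, not the atlas-based generator: the names `substrings`, `enum_str`, and `enum_list` are hypothetical, `depth` counts down to zero, and `join` separators are restricted to substrings of the output to keep the search space small.

```python
from itertools import product


def substrings(s):
    """All non-empty contiguous substrings of s, without duplicates."""
    return {s[i:j] for i in range(len(s)) for j in range(i + 1, len(s) + 1)}


def enum_str(inp, out, depth):
    """Yield (value, program) pairs for the nonterminal S."""
    yield inp, "inp"
    # Constants: substrings of the input or of the output
    for sub in sorted(substrings(inp) | substrings(out)):
        yield sub, repr(sub)
    if depth <= 0:
        return
    lists = list(enum_list(inp, out, depth - 1))
    # S ::= L[i]
    for values, prog in lists:
        for i, value in enumerate(values):
            yield value, f"{prog}[{i}]"
    # S ::= S.join(L), with the separator limited to output substrings (assumption)
    for sep in sorted(substrings(out)):
        for values, prog in lists:
            yield sep.join(values), f"{sep!r}.join({prog})"


def enum_list(inp, out, depth):
    """Yield (value, program) pairs for the nonterminal L ::= S.split(S)."""
    strs = list(enum_str(inp, out, depth))
    for (s_val, s_prog), (sep_val, sep_prog) in product(strs, repeat=2):
        if sep_val:  # str.split raises ValueError on an empty separator
            yield s_val.split(sep_val), f"{s_prog}.split({sep_prog})"


# Collect every enumerated program whose value matches the desired output
programs = {p for v, p in enum_str("Alan Turing", "Alan", depth=1) if v == "Alan"}
print("inp.split(' ')[0]" in programs)
```

The matching set also contains the constant program `'Alan'`, which reproduces the output without generalizing to other names. Blind enumeration cannot tell the two apart, and the number of candidates grows quickly with depth.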

## 1. Writing the Generator

Here are two top-level generators for `S` and `L`. We use a `depth` parameter to bound the nesting of function calls. For convenience, each generator returns the result of running the program along with the program itself.

In [1]:
from atlas import generator
from atlas.exceptions import ExceptionAsContinue


@generator
def gen_str_program(inp, out, depth=0, max_depth=3):
    if depth >= max_depth:
        raise ExceptionAsContinue
        
    operation = Select(["Inp", "ConstInp", "ConstOut", "Index", "Join"], context=(inp, out), uid="1")
    if operation == "Inp":
        return inp, "inp"
    
    if operation == "ConstInp":
        substr = Substr(inp, context=(inp, out))
        return substr, f"'{substr}'"
    
    if operation == "ConstOut":
        substr = Substr(out, context=(inp, out))
        return substr, f"'{substr}'"
    
    if operation == "Index":
        #  Generate a program that produces a list of strings
        strlist, list_prog = gen_strlist_program(inp, out, depth=depth+1, max_depth=max_depth)
        index = Select(range(len(strlist)))  # Use the result of the list-program to bound the number of indices!
        return strlist[index], f"{list_prog}[{index}]"
    
    if operation == "Join":
        return gen_join_programs(inp, out, depth=depth, max_depth=max_depth)
        
        
@generator
def gen_strlist_program(inp, out, depth=0, max_depth=3):
    tosplit_str, tosplit_prog = gen_str_program(inp, out, depth=depth + 1, max_depth=max_depth)
    sep_str, sep_prog = gen_str_program(inp, out, depth=depth + 1, max_depth=max_depth)
    
    return tosplit_str.split(sep_str), f"{tosplit_prog}.split({sep_prog})"

#### Exercise: Finish the `gen_join_programs` Generator

In [2]:
@generator
def gen_join_programs(inp, out, depth=0, max_depth=3):
    join_str, str_prog = gen_str_program(inp, out, depth=depth + 1, max_depth=max_depth)
    #  FILL IN THE RHS HERE
    strlist, list_prog = gen_strlist_program(inp, out, depth=depth + 1, max_depth=max_depth)
    
    return join_str.join(strlist), f"{str_prog}.join({list_prog})"

Let's check out some of the programs it generates!

In [3]:
for res, prog in gen_str_program.generate("Alan Turing", "Alan").with_strategy('randomized').first(k=5):
    print("Program :", prog)

Program : inp
Program : inp
Program : inp.split('a')[1]
Program : 'lan '
Program : 'la'
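
The program strings printed above are ordinary Python expressions over `inp`, so one way to sanity-check them is to evaluate them with `inp` bound to the input. The helper `run_program` below is hypothetical and not part of atlas; it is a minimal sketch that should only be used on trusted program strings.

```python
def run_program(prog, inp):
    """Evaluate a generated program string with `inp` bound to the input.

    Hypothetical helper for illustration; eval is only safe on trusted strings.
    """
    # Builtins are emptied out; string methods such as split/join still work.
    return eval(prog, {"__builtins__": {}}, {"inp": inp})


for prog in ["inp", "inp.split('a')[1]", "'lan '"]:
    print(prog, "->", repr(run_program(prog, "Alan Turing")))
```

None of these three programs evaluates to the target `Alan`, which is unsurprising for a randomized strategy with no model guiding it.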


## 2. Defining the Model

The real drivers of this generator are the top-level Select (choosing the operation) and the substring operators. So we'll define a model for them.

TODO : Model Description Diagram

In [4]:
from atlas.operators import operator
from atlas.models.pytorch.imitation import PyTorchGeneratorSharedStateModel
from string_models import SubstrModel, SelectFuncModel

class MasterModel(PyTorchGeneratorSharedStateModel):
    @operator(name='Substr')
    def substr_model(self, *args, **kwargs):
        return SubstrModel(node_dim=10)

    @operator(name='Select', uid="1")
    def select_func_model(self, *args, **kwargs):
        return SelectFuncModel(node_dim=10, num_classes=5)

## 3. Task-1 Training

For this tutorial, we'll train the generator to specialize in two tasks:

(1) Extraction via splitting (`split`)

(2) Separator replacement (`split` followed by `join`)
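
In plain Python, the two task families look roughly like this. The function names `extract_field` and `replace_separator` are hypothetical and serve only to show the shape of the target programs.

```python
def extract_field(inp, sep, index):
    """Task 1: extraction via splitting, i.e. inp.split(sep)[index]."""
    return inp.split(sep)[index]


def replace_separator(inp, old_sep, new_sep):
    """Task 2: separator replacement, i.e. new_sep.join(inp.split(old_sep))."""
    return new_sep.join(inp.split(old_sep))


print(extract_field("Alan Turing", " ", 0))
print(replace_separator("Emma.Aaron", ".", "__"))
```
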

Let's load some training data for that.

In [5]:
import pickle
from urllib.request import urlopen
data = pickle.load(urlopen("https://risecamp2019-atlas.s3.us-east-2.amazonaws.com/string_dataset_medium.pkl"))

print(data[0])
print(data[1])


        GeneratorTrace(inputs=(('Emma.Aaron', 'Emma__Aaron'), {}),
                       op_traces=[
OpTrace(op_info=OpInfo(sid='/gen_str_program/Select@1@1', gen_name='gen_str_program', op_type='Select', index=1, gen_group=None, uid='1', tags=None),
        choice='Join',
        domain=['Inp', 'ConstInp', 'ConstOut', 'Index', 'Join'],
        context=('Emma.Aaron', 'Emma__Aaron'),
        **{}
       ), 
OpTrace(op_info=OpInfo(sid='/gen_str_program/Select@1@1', gen_name='gen_str_program', op_type='Select', index=1, gen_group=None, uid='1', tags=None),
        choice='ConstOut',
        domain=['Inp', 'ConstInp', 'ConstOut', 'Index', 'Join'],
        context=('Emma.Aaron', 'Emma__Aaron'),
        **{}
       ), 
OpTrace(op_info=OpInfo(sid='/gen_str_program/Substr@@2', gen_name='gen_str_program', op_type='Substr', index=2, gen_group=None, uid=None, tags=None),
        choice='__',
        domain='Emma__Aaron',
        context=('Emma.Aaron', 'Emma__Aaron'),
        **{}
       ), 
OpT

Let's train!

In [17]:
model = MasterModel(state_dim=10, learning_rate=0.01)
train = data[:200]
valid = data[200:]
model.train(train, valid, num_epochs=10)

[Epoch 0] Training Loss: 3.3169 Training Acc: 0.01
[Epoch 0] Validation Loss: 2.4109 Validation Acc: 0.04
[Epoch 1] Training Loss: 1.5975 Training Acc: 0.08
[Epoch 1] Validation Loss: 1.0876 Validation Acc: 0.17
[Epoch 2] Training Loss: 0.8878 Training Acc: 0.17
[Epoch 2] Validation Loss: 0.6487 Validation Acc: 0.17
[Epoch 3] Training Loss: 0.6299 Training Acc: 0.19
[Epoch 3] Validation Loss: 0.9274 Validation Acc: 0.22
[Epoch 4] Training Loss: 0.5846 Training Acc: 0.24
[Epoch 4] Validation Loss: 0.4336 Validation Acc: 0.39
[Epoch 5] Training Loss: 0.4931 Training Acc: 0.47
[Epoch 5] Validation Loss: 0.6275 Validation Acc: 0.54
[Epoch 6] Training Loss: 0.4281 Training Acc: 0.54
[Epoch 6] Validation Loss: 0.3361 Validation Acc: 0.55
[Epoch 7] Training Loss: 0.3950 Training Acc: 0.54
[Epoch 7] Validation Loss: 0.4137 Validation Acc: 0.56
[Epoch 8] Training Loss: 0.3694 Training Acc: 0.54
[Epoch 8] Validation Loss: 0.3010 Validation Acc: 0.54
[Epoch 9] Training Loss: 0.3305 Training Acc: 

In [20]:
def synthesize(inp, out):
    for res, prog in gen_str_program.generate(inp, out).with_model(model).first(k=500):
        if res == out:
            print("Solution Found :", prog)

In [22]:
synthesize("Alan Mathison Turing", "Alan.Mathison.Turing")

Solution Found : '.'.join(inp.split(' '))
Solution Found : '.'.join(inp.split(' '))
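
The same program is reported twice above because `synthesize` prints every matching candidate. One possible way to report each solution once is to track the programs already seen. The helper `unique_solutions` is a hypothetical sketch that works on any iterable of `(result, program)` pairs; the list `candidates` is a stand-in for the generator's output.

```python
def unique_solutions(candidates, out):
    """Yield each program whose result equals `out`, skipping repeats.

    `candidates` is any iterable of (result, program) pairs.
    """
    seen = set()
    for res, prog in candidates:
        if res == out and prog not in seen:
            seen.add(prog)
            yield prog


# Stand-in candidates; in the notebook these would come from the generator.
candidates = [
    ("Alan.Mathison.Turing", "'.'.join(inp.split(' '))"),
    ("Alan Mathison Turing", "inp"),
    ("Alan.Mathison.Turing", "'.'.join(inp.split(' '))"),
]
print(list(unique_solutions(candidates, "Alan.Mathison.Turing")))
```

Because the helper is a generator, it can also be stopped after the first solution instead of exhausting all candidates.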


Great! Feel free to try input-output pairs of your own. Be aware that this generator has been trained on only two tasks and on a limited amount of data, so it may not work as well on other kinds of inputs.

## Full Training

We've already trained the generator on large amounts of data containing many tasks. You can test out the synthesizer here.

## TODO